From 9e53c823b8ce7c04a310bde621197001753a63af Mon Sep 17 00:00:00 2001 From: imaginary-person Date: Thu, 22 Jul 2021 08:49:55 -0700 Subject: [PATCH] Add AVX512 support in ATen & remove AVX support (#61903) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61903 ### Remaining Tasks - [ ] Collate results of benchmarks on two Intel Xeon machines (with & without CUDA, to check if CPU throttling causes issues with GPUs) - make graphs, including Roofline model plots (Intel Advisor can't make them with libgomp, though, but with Intel OpenMP). ### Summary 1. This draft PR produces binaries with with 3 types of ATen kernels - default, AVX2, AVX512 . Using the environment variable `ATEN_AVX512_256=TRUE` also results in 3 types of kernels, but the compiler can use 32 ymm registers for AVX2, instead of the default 16. ATen kernels for `CPU_CAPABILITY_AVX` have been removed. 2. `nansum` is not using AVX512 kernel right now, as it has poorer accuracy for Float16, than does AVX2 or DEFAULT, whose respective accuracies aren't very good either (#59415). It was more convenient to disable AVX512 dispatch for all dtypes of `nansum` for now. 3. On Windows , ATen Quantized AVX512 kernels are not being used, as quantization tests are flaky. If `--continue-through-failure` is used, then `test_compare_model_outputs_functional_static` fails. But if this test is skipped, `test_compare_model_outputs_conv_static` fails. If both these tests are skipped, then a third one fails. These are hard to debug right now due to not having access to a Windows machine with AVX512 support, so it was more convenient to disable AVX512 dispatch of all ATen Quantized kernels on Windows for now. 4. One test is currently being skipped - [test_lstm` in `quantization.bc](https://github.com/pytorch/pytorch/issues/59098) - It fails only on Cascade Lake machines, irrespective of the `ATEN_CPU_CAPABILITY` used, because FBGEMM uses `AVX512_VNNI` on machines that support it. The value of `reduce_range` should be used as `False` on such machines. The list of the changes is at https://gist.github.com/imaginary-person/4b4fda660534f0493bf9573d511a878d. Credits to ezyang for proposing `AVX512_256` - these use AVX2 intrinsics but benefit from 32 registers, instead of the 16 ymm registers that AVX2 uses. Credits to limo1996 for the initial proposal, and for optimizing `hsub_pd` & `hadd_pd`, which didn't have direct AVX512 equivalents, and are being used in some kernels. He also refactored `vec/functional.h` to remove duplicated code. Credits to quickwritereader for helping fix 4 failing complex multiplication & division tests. ### Testing 1. `vec_test_all_types` was modified to test basic AVX512 support, as tests already existed for AVX2. Only one test had to be modified, as it was hardcoded for AVX2. 2. `pytorch_linux_bionic_py3_8_gcc9_coverage_test1` & `pytorch_linux_bionic_py3_8_gcc9_coverage_test2` are now using `linux.2xlarge` instances, as they support AVX512. They were used for testing AVX512 kernels, as AVX512 kernels are being used by default in both of the CI checks. Windows CI checks had already been using machines with AVX512 support. ### Would the downclocking caused by AVX512 pose an issue? I think it's important to note that AVX2 causes downclocking as well, and the additional downclocking caused by AVX512 may not hamper performance on some Skylake machines & beyond, because of the double vector-size. I think that [this post with verifiable references is a must-read](https://community.intel.com/t5/Software-Tuning-Performance/Unexpected-power-vs-cores-profile-for-MKL-kernels-on-modern-Xeon/m-p/1133869/highlight/true#M6450). Also, AVX512 would _probably not_ hurt performance on a high-end machine, [but measurements are recommended](https://lemire.me/blog/2018/09/07/avx-512-when-and-how-to-use-these-new-instructions/). In case it does, `ATEN_AVX512_256=TRUE` can be used for building PyTorch, as AVX2 can then use 32 ymm registers instead of the default 16. [FBGEMM uses `AVX512_256` only on Xeon D processors](https://github.com/pytorch/FBGEMM/pull/209), which are said to have poor AVX512 performance. This [official data](https://www.intel.com/content/dam/www/public/us/en/documents/specification-updates/xeon-scalable-spec-update.pdf) is for the Intel Skylake family, and the first link helps understand its significance. Cascade Lake & Ice Lake SP Xeon processors are said to be even better when it comes to AVX512 performance. Here is the corresponding data for [Cascade Lake](https://cdrdv2.intel.com/v1/dl/getContent/338848) - ![CASCADE LAKE AVX2](https://user-images.githubusercontent.com/76181208/120666172-ffec3f80-c451-11eb-8ea1-8933ccc12a1b.PNG) ![CASCADE LAKE AVX512](https://user-images.githubusercontent.com/76181208/120666190-04b0f380-c452-11eb-9faa-38d233c874c8.PNG) The corresponding data isn't publicly available for Intel Xeon SP 3rd gen (Ice Lake SP), but [Intel mentioned that the 3rd gen has frequency improvements pertaining to AVX512](https://newsroom.intel.com/wp-content/uploads/sites/11/2021/04/3rd-Gen-Intel-Xeon-Scalable-Platform-Press-Presentation-281884.pdf). Ice Lake SP machines also have 48 KB L1D caches, so that's another reason for AVX512 performance to be better on them. ### Is PyTorch always faster with AVX512? No, but then PyTorch is not always faster with AVX2 either. Please refer to #60202. The benefit from vectorization is apparent with with small tensors that fit in caches or in kernels that are more compute heavy. For instance, AVX512 or AVX2 would yield no benefit for adding two 64 MB tensors, but adding two 1 MB tensors would do well with AVX2, and even more so with AVX512. It seems that memory-bound computations, such as adding two 64 MB tensors can be slow with vectorization (depending upon the number of threads used), as the effects of downclocking can then be observed. Original pull request: https://github.com/pytorch/pytorch/pull/56992 Reviewed By: soulitzer Differential Revision: D29266289 Pulled By: ezyang fbshipit-source-id: 2d5e8d1c2307252f22423bbc14f136c67c3e6184 --- .jenkins/pytorch/test.sh | 4 +- aten.bzl | 3 +- aten/src/ATen/CMakeLists.txt | 2 +- aten/src/ATen/Version.cpp | 6 +- aten/src/ATen/cpu/FlushDenormal.cpp | 3 +- aten/src/ATen/cpu/vec/functional.h | 7 +- .../cpu/vec/{vec256 => }/functional_base.h | 2 +- .../vec/{vec256 => }/functional_bfloat16.h | 18 +- .../ATen/cpu/vec/{vec256 => }/intrinsics.h | 7 +- aten/src/ATen/cpu/vec/vec.h | 4 + aten/src/ATen/cpu/vec/vec256/functional.h | 6 - aten/src/ATen/cpu/vec/vec256/vec256.h | 35 +- .../src/ATen/cpu/vec/vec256/vec256_bfloat16.h | 39 +- .../cpu/vec/vec256/vec256_complex_double.h | 17 +- .../cpu/vec/vec256/vec256_complex_float.h | 16 +- aten/src/ATen/cpu/vec/vec256/vec256_double.h | 21 +- aten/src/ATen/cpu/vec/vec256/vec256_float.h | 23 +- .../ATen/cpu/vec/vec256/vec256_float_neon.h | 50 +- aten/src/ATen/cpu/vec/vec256/vec256_int.h | 26 +- aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 330 +---- .../cpu/vec/vec256/vsx/vec256_common_vsx.h | 4 +- .../vec256/vsx/vec256_complex_double_vsx.h | 16 +- .../vec/vec256/vsx/vec256_complex_float_vsx.h | 16 +- .../cpu/vec/vec256/vsx/vec256_double_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_float_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_int16_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_int32_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_int64_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_qint32_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_qint8_vsx.h | 8 +- .../cpu/vec/vec256/vsx/vec256_quint8_vsx.h | 8 +- .../src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h | 2 +- aten/src/ATen/cpu/vec/vec512/vec512.h | 195 +++ .../src/ATen/cpu/vec/vec512/vec512_bfloat16.h | 879 ++++++++++++ .../cpu/vec/vec512/vec512_complex_double.h | 526 ++++++++ .../cpu/vec/vec512/vec512_complex_float.h | 1030 ++++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512_double.h | 454 +++++++ aten/src/ATen/cpu/vec/vec512/vec512_float.h | 469 +++++++ aten/src/ATen/cpu/vec/vec512/vec512_int.h | 1173 ++++++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512_qint.h | 1195 +++++++++++++++++ .../vec/{vec256/vec256_base.h => vec_base.h} | 176 ++- aten/src/ATen/cpu/vml.h | 57 +- .../ATen/native/BatchLinearAlgebraKernel.cpp | 42 +- aten/src/ATen/native/DispatchStub.cpp | 45 +- aten/src/ATen/native/DispatchStub.h | 36 +- aten/src/ATen/native/SegmentReduce.cpp | 6 +- aten/src/ATen/native/cpu/README.md | 16 +- aten/src/ATen/native/cpu/Reduce.h | 13 +- aten/src/ATen/native/cpu/SoftMaxKernel.cpp | 8 +- aten/src/ATen/native/cpu/SumKernel.cpp | 9 +- aten/src/ATen/native/cpu/UnaryOpsKernel.cpp | 7 + aten/src/ATen/native/cpu/avx_mathfun.h | 207 +-- aten/src/ATen/native/mkl/SpectralOps.cpp | 2 +- .../cpu/kernels/QuantizedOpKernels.cpp | 320 ++++- aten/src/ATen/test/vec_test_all_types.cpp | 21 +- aten/src/ATen/test/vec_test_all_types.h | 11 +- caffe2/CMakeLists.txt | 29 +- cmake/Codegen.cmake | 44 +- cmake/Modules/FindAVX.cmake | 21 + setup.py | 7 + test/cpp/api/dispatch.cpp | 10 +- .../bc/test_backward_compatibility.py | 5 +- torch/testing/_internal/common_utils.py | 9 + 63 files changed, 6772 insertions(+), 971 deletions(-) rename aten/src/ATen/cpu/vec/{vec256 => }/functional_base.h (99%) rename aten/src/ATen/cpu/vec/{vec256 => }/functional_bfloat16.h (97%) rename aten/src/ATen/cpu/vec/{vec256 => }/intrinsics.h (86%) delete mode 100644 aten/src/ATen/cpu/vec/vec256/functional.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_double.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_float.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_int.h create mode 100644 aten/src/ATen/cpu/vec/vec512/vec512_qint.h rename aten/src/ATen/cpu/vec/{vec256/vec256_base.h => vec_base.h} (84%) diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 106dd098dcc..4f2fe2615e5 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -132,7 +132,9 @@ fi if [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX-* || $TEST_CONFIG == 'nogpu_NO_AVX' ]]; then export ATEN_CPU_CAPABILITY=default elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* || $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then - export ATEN_CPU_CAPABILITY=avx + export ATEN_CPU_CAPABILITY=default +elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX512-* || $TEST_CONFIG == 'nogpu_NO_AVX512' ]]; then + export ATEN_CPU_CAPABILITY=avx2 fi if [ -n "$IN_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then diff --git a/aten.bzl b/aten.bzl index 6bce36ca904..c2fcee7323d 100644 --- a/aten.bzl +++ b/aten.bzl @@ -1,9 +1,8 @@ load("@rules_cc//cc:defs.bzl", "cc_library") -CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX", "AVX2"] +CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"] CAPABILITY_COMPILER_FLAGS = { "AVX2": ["-mavx2", "-mfma"], - "AVX": ["-mavx"], "DEFAULT": [], } diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index cf4c192d5de..3e160a20102 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -50,7 +50,7 @@ if(NOT BUILD_LITE_INTERPRETER) endif() EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS}) -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 43d81cbdbe3..750c90bb4c5 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -108,12 +108,12 @@ std::string used_cpu_capability() { case native::CPUCapability::DEFAULT: ss << "NO AVX"; break; - case native::CPUCapability::AVX: - ss << "AVX"; - break; case native::CPUCapability::AVX2: ss << "AVX2"; break; + case native::CPUCapability::AVX512: + ss << "AVX512"; + break; #endif default: break; diff --git a/aten/src/ATen/cpu/FlushDenormal.cpp b/aten/src/ATen/cpu/FlushDenormal.cpp index 7c7df405be5..c1d330f6a74 100644 --- a/aten/src/ATen/cpu/FlushDenormal.cpp +++ b/aten/src/ATen/cpu/FlushDenormal.cpp @@ -1,6 +1,5 @@ #include - -#include +#include #include namespace at { namespace cpu { diff --git a/aten/src/ATen/cpu/vec/functional.h b/aten/src/ATen/cpu/vec/functional.h index c9a9c443a16..210ae9e9e88 100644 --- a/aten/src/ATen/cpu/vec/functional.h +++ b/aten/src/ATen/cpu/vec/functional.h @@ -1 +1,6 @@ -#include +#pragma once + +#include +#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) +#include +#endif diff --git a/aten/src/ATen/cpu/vec/vec256/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h similarity index 99% rename from aten/src/ATen/cpu/vec/vec256/functional_base.h rename to aten/src/ATen/cpu/vec/functional_base.h index 519f1008788..7bd04e637c7 100644 --- a/aten/src/ATen/cpu/vec/vec256/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -3,7 +3,7 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include +#include namespace at { namespace vec { diff --git a/aten/src/ATen/cpu/vec/vec256/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h similarity index 97% rename from aten/src/ATen/cpu/vec/vec256/functional_bfloat16.h rename to aten/src/ATen/cpu/vec/functional_bfloat16.h index 442c587a6c2..9efa7004090 100644 --- a/aten/src/ATen/cpu/vec/vec256/functional_bfloat16.h +++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h @@ -3,7 +3,7 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include +#include namespace at { namespace vec { @@ -15,26 +15,26 @@ template <> struct VecScalarType { using type = float; }; template using vec_scalar_t = typename VecScalarType::type; -// Note that we already have specializes member of Vectorized for BFloat16 -// so the following function would run smoothly: +// Note that we already have specialized member of Vectorized for BFloat16 +// so the following functions would run smoothly: // using Vec = Vectorized; // Vec one = Vec(BFloat16(1)); // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); // -// Why we still need to specializes "funtional"? +// Then why we still need to specialize "funtional"? // If we do specialization at Vectorized<> level, the above example would need 3 pairs of -// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". +// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". // If we do specialization at vec::map<>() level, we have only 1 pair of conversion -// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. +// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. // -// The following BFloat16 functionalities will only do data type conversion for input -// and output vector (reduce functionalities will only convert the final scalar back to bf16). +// The following BFloat16 functionality will only do data type conversion for input +// and output vector (reduce functionality will only convert the final scalar back to bf16). // Compared to Vectorized<> specialization, // 1. better performance since we have less data type conversion; // 2. less rounding error since immediate results are kept in fp32; // 3. accumulation done on data type of fp32. // -// If you plan to extend this file, make sure add unit test at +// If you plan to extend this file, please ensure adding unit tests at // aten/src/ATen/test/vec_test_all_types.cpp // template diff --git a/aten/src/ATen/cpu/vec/vec256/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h similarity index 86% rename from aten/src/ATen/cpu/vec/vec256/intrinsics.h rename to aten/src/ATen/cpu/vec/intrinsics.h index 5ac4d484ccd..a6a73e232e1 100644 --- a/aten/src/ATen/cpu/vec/vec256/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -1,6 +1,6 @@ #pragma once -#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) -/* Clang-compatible compiler, targeting x86/x86-64 */ +#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC or clang-compatible compiler, targeting x86/x86-64 */ #include #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* Clang-compatible compiler, targeting arm neon */ @@ -14,9 +14,6 @@ #define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) #define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) #endif -#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) -/* GCC-compatible compiler, targeting x86/x86-64 */ -#include #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ #include diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index 5a041f2df70..24b8818d2a8 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -1 +1,5 @@ +#if defined(CPU_CAPABILITY_AVX512) +#include +#else #include +#endif diff --git a/aten/src/ATen/cpu/vec/vec256/functional.h b/aten/src/ATen/cpu/vec/vec256/functional.h deleted file mode 100644 index e1ddc3c9cc7..00000000000 --- a/aten/src/ATen/cpu/vec/vec256/functional.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include -#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) -#include -#endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 4200e6729af..0d13458bc4c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -3,9 +3,9 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include +#include -#include +#include #if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) #include #include @@ -68,9 +68,9 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { } -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template<> inline Vectorized cast(const Vectorized& src) { @@ -82,29 +82,6 @@ inline Vectorized cast(const Vectorized& src) { return _mm256_castps_pd(src); } -#if defined(CPU_CAPABILITY_AVX2) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -#define DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) \ -template<> \ -inline Vectorized cast(const Vectorized& src) { \ - return _mm256_castp ## float_ch ## _si256(src); \ -} \ -template<> \ -inline Vectorized cast(const Vectorized& src) { \ - return _mm256_castsi256_p ## float_ch (src); \ -} - -DEFINE_FLOAT_INT_CAST(int64_t, double, d) -DEFINE_FLOAT_INT_CAST(int32_t, double, d) -DEFINE_FLOAT_INT_CAST(int16_t, double, d) -DEFINE_FLOAT_INT_CAST(int64_t, float, s) -DEFINE_FLOAT_INT_CAST(int32_t, float, s) -DEFINE_FLOAT_INT_CAST(int16_t, float, s) - -#undef DEFINE_FLOAT_INT_CAST - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -243,8 +220,6 @@ inline deinterleave2(const Vectorized& a, const Vectorized& _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart } -#endif // defined(CPU_CAPABILITY_AVX2) - -#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index 7d61d8a9d38..82a8200ce2c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -100,7 +100,7 @@ public: return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int16_t count) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); return loadu(tmp_values); } @@ -108,14 +108,14 @@ public: if (count == size()) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); } } template static Vectorized blend(const Vectorized& a, const Vectorized& b) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi16(b.values, 0); @@ -280,7 +280,7 @@ public: Vectorized erfinv() const { __m256 lo, hi; cvtbf16_fp32(values, lo, hi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); for (int64_t i = 0; i < size() / 2; i++) { @@ -318,7 +318,7 @@ public: Vectorized i0() const { __m256 lo, hi; cvtbf16_fp32(values, lo, hi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); for (int64_t i = 0; i < size() / 2; i++) { @@ -333,7 +333,7 @@ public: __m256 lo, hi; cvtbf16_fp32(values, lo, hi); constexpr auto sz = size(); - __at_align32__ float tmp1[sz / 2], tmp2[sz / 2]; + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); @@ -350,10 +350,10 @@ public: __m256 xlo, xhi; cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(x.values, xlo, xhi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); - __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); for (int64_t i = 0; i < size() / 2; ++i) { @@ -370,10 +370,10 @@ public: __m256 xlo, xhi; cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(x.values, xlo, xhi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); - __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); for (int64_t i = 0; i < size() / 2; ++i) { @@ -717,12 +717,13 @@ inline Vectorized convert_float_bfloat16(const Vectorized& a, c return cvtfp32_bf16(__m256(a), __m256(b)); } -#else //defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { constexpr int64_t K = Vectorized::size(); - __at_align32__ float arr[K]; - __at_align32__ BFloat16 arr2[K]; + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; a.store(arr2); convert(arr2, arr, K); return std::make_tuple( @@ -732,15 +733,15 @@ inline std::tuple, Vectorized> convert_bfloat16_float(c inline Vectorized convert_float_bfloat16(const Vectorized& a, const Vectorized& b) { constexpr int64_t K = Vectorized::size(); - __at_align32__ float arr[K]; - __at_align32__ BFloat16 arr2[K]; + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; a.store(arr); b.store(arr + Vectorized::size()); convert(arr, arr2, K); return Vectorized::loadu(arr2); } -#endif +#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { @@ -759,7 +760,7 @@ void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out1, Vec } #else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { - __at_align32__ float values[Vectorized::size()]; + __at_align__ float values[Vectorized::size()]; for (int k = 0; k < Vectorized::size(); ++k) { values[k] = data[k]; } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h index f96aea6e09e..40276ba8365 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -4,9 +4,10 @@ // See Note [Do not compile initializers with AVX] #include -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -15,7 +16,7 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized> { private: @@ -81,7 +82,7 @@ public: if (count == size()) return _mm256_loadu_pd(reinterpret_cast(ptr)); - __at_align32__ double tmp_values[2*size()]; + __at_align__ double tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -106,7 +107,7 @@ public: const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { - __at_align32__ c10::complex tmp[size()]; + __at_align__ c10::complex tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -288,8 +289,8 @@ public: return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { - __at_align32__ c10::complex x_tmp[size()]; - __at_align32__ c10::complex y_tmp[size()]; + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h index 5494828b565..f4019632002 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -4,9 +4,9 @@ // See Note [Do not compile initializers with AVX] #include -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -15,7 +15,7 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized> { private: @@ -117,7 +117,7 @@ public: if (count == size()) return _mm256_loadu_ps(reinterpret_cast(ptr)); - __at_align32__ float tmp_values[2*size()]; + __at_align__ float tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -142,7 +142,7 @@ public: const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { - __at_align32__ c10::complex tmp[size()]; + __at_align__ c10::complex tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -323,8 +323,8 @@ public: return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { - __at_align32__ c10::complex x_tmp[size()]; - __at_align32__ c10::complex y_tmp[size()]; + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h index 1c575b9a28c..f92f44e562a 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h @@ -3,9 +3,9 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -14,7 +14,8 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized { private: @@ -67,7 +68,7 @@ public: return _mm256_loadu_pd(reinterpret_cast(ptr)); - __at_align32__ double tmp_values[size()]; + __at_align__ double tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -100,7 +101,7 @@ public: return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); } Vectorized map(double (*const f)(double)) const { - __at_align32__ double tmp[size()]; + __at_align__ double tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -175,8 +176,8 @@ public: return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { - __at_align32__ double tmp[size()]; - __at_align32__ double tmp_x[size()]; + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -185,8 +186,8 @@ public: return loadu(tmp); } Vectorized igammac(const Vectorized &x) const { - __at_align32__ double tmp[size()]; - __at_align32__ double tmp_x[size()]; + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index 1f4c3f63477..deb95429843 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -3,9 +3,9 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -14,7 +14,7 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized { private: @@ -76,7 +76,7 @@ public: static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return _mm256_loadu_ps(reinterpret_cast(ptr)); - __at_align32__ float tmp_values[size()]; + __at_align__ float tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -107,7 +107,7 @@ public: return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); } Vectorized map(float (*const f)(float)) const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -213,8 +213,8 @@ public: return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -223,8 +223,8 @@ public: return loadu(tmp); } Vectorized igammac(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -412,12 +412,11 @@ inline void convert(const float* src, float* dst, int64_t n) { } } -#ifdef CPU_CAPABILITY_AVX2 + template <> Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return _mm256_fmadd_ps(a, b, c); } -#endif #endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h index b39d808a13a..2aac442d212 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include // Sleef offers vectorized versions of some transcedentals // such as sin, cos, tan etc.. // However for now opting for STL, since we are not building @@ -220,7 +220,7 @@ public: return res; } else { - __at_align32__ float tmp_values[size()]; + __at_align__ float tmp_values[size()]; for (auto i = 0; i < size(); ++i) { tmp_values[i] = 0.0; } @@ -261,19 +261,19 @@ public: // Once we specialize that implementation for ARM // this should be removed. TODO (kimishpatel) float operator[](int idx) const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); return tmp[idx]; } float operator[](int idx) { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); return tmp[idx]; } // For boolean version where we want to if any 1/all zero // etc. can be done faster in a different way. int zero_mask() const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); int mask = 0; for (int i = 0; i < size(); ++ i) { @@ -284,8 +284,8 @@ public: return mask; } Vectorized isnan() const { - __at_align32__ float tmp[size()]; - __at_align32__ float res[size()]; + __at_align__ float tmp[size()]; + __at_align__ float res[size()]; store(tmp); for (int i = 0; i < size(); i++) { if (_isnan(tmp[i])) { @@ -297,7 +297,7 @@ public: return loadu(res); }; Vectorized map(float (*const f)(float)) const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -332,8 +332,8 @@ public: return map(std::atan); } Vectorized atan2(const Vectorized &exp) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_exp[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_exp[size()]; store(tmp); exp.store(tmp_exp); for (int64_t i = 0; i < size(); i++) { @@ -342,8 +342,8 @@ public: return loadu(tmp); } Vectorized copysign(const Vectorized &sign) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_sign[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; store(tmp); sign.store(tmp_sign); for (size_type i = 0; i < size(); i++) { @@ -367,8 +367,8 @@ public: return map(std::expm1); } Vectorized fmod(const Vectorized& q) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_q[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; store(tmp); q.store(tmp_q); for (int64_t i = 0; i < size(); i++) { @@ -377,8 +377,8 @@ public: return loadu(tmp); } Vectorized hypot(const Vectorized &b) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_b[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { @@ -393,8 +393,8 @@ public: return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -403,8 +403,8 @@ public: return loadu(tmp); } Vectorized igammac(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -425,8 +425,8 @@ public: return map(std::log2); } Vectorized nextafter(const Vectorized &b) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_b[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { @@ -490,8 +490,8 @@ public: return this->sqrt().reciprocal(); } Vectorized pow(const Vectorized &exp) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_exp[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_exp[size()]; store(tmp); exp.store(tmp_exp); for (int64_t i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 6f85988bcf4..86cf42556d1 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include #include namespace at { @@ -55,7 +55,7 @@ public: } template static Vectorized blend(Vectorized a, Vectorized b) { - __at_align32__ int64_t tmp_values[size()]; + __at_align__ int64_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi64(b.values, 0); @@ -93,7 +93,7 @@ public: return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int64_t count) { - __at_align32__ int64_t tmp_values[size()]; + __at_align__ int64_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -109,7 +109,7 @@ public: // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int64_t tmp_values[size()]; + __at_align__ int64_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); } @@ -216,7 +216,7 @@ public: return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int32_t count) { - __at_align32__ int32_t tmp_values[size()]; + __at_align__ int32_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -232,7 +232,7 @@ public: // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int32_t tmp_values[size()]; + __at_align__ int32_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); } @@ -346,7 +346,7 @@ public: } template static Vectorized blend(Vectorized a, Vectorized b) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi16(b.values, 0); @@ -436,7 +436,7 @@ public: return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int16_t count) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -452,7 +452,7 @@ public: // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); } @@ -527,7 +527,7 @@ public: } template static Vectorized blend(Vectorized a, Vectorized b) { - __at_align32__ int8_t tmp_values[size()]; + __at_align__ int8_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi8(b.values, 0); @@ -685,7 +685,7 @@ public: return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int8_t count) { - __at_align32__ int8_t tmp_values[size()]; + __at_align__ int8_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -701,7 +701,7 @@ public: // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int8_t tmp_values[size()]; + __at_align__ int8_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 874be68e523..dc5e8331273 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include #include #include #include @@ -39,7 +39,7 @@ namespace at { namespace vec { namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) struct Vectorizedqi { protected: @@ -53,7 +53,6 @@ struct Vectorizedqi { } }; -#if defined(CPU_CAPABILITY_AVX2) template __m256i pack_saturate_and_clamp( __m256i first, @@ -94,7 +93,6 @@ __m256i pack_saturate_and_clamp( _mm256_set1_epi8(min_val), _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); } -#endif template inline void __attribute__((always_inline)) QuantizeAvx2( @@ -103,7 +101,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2( int len, float inverse_scale, int64_t zero_point) { -#if defined(CPU_CAPABILITY_AVX2) constexpr int VLEN = 8; constexpr auto min_val = std::numeric_limits::min(); constexpr auto max_val = std::numeric_limits::max(); @@ -212,10 +209,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2( std::min(std::max(transformed, float(min_val)), float(max_val)); dst[i] = clipped; } -#else - at::native::quantize_vec( - 1.0f / inverse_scale, zero_point, src, reinterpret_cast(dst), len); -#endif } template<> @@ -266,11 +259,7 @@ struct Vectorized : public Vectorizedqi { Vectorized zero_point, Vectorized scale_zp_premul) const { __m256 float_vals = _mm256_cvtepi32_ps(vals); -#if defined(CPU_CAPABILITY_AVX2) return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; -#else - return {scale * (Vectorized(float_vals) - zero_point)}; -#endif } static Vectorized quantize( @@ -286,39 +275,11 @@ struct Vectorized : public Vectorizedqi { } Vectorized maximum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_max_epi32(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(int_vals.data()), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(b_vals.data()), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized minimum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi32(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized relu(Vectorized zero_point) const { @@ -328,65 +289,24 @@ struct Vectorized : public Vectorizedqi { Vectorized relu6( Vectorized zero_point, Vectorized q_six) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi32( _mm256_max_epi32(vals, zero_point.vals), q_six.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array zero_point_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); - std::array q_six_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min( - std::max(int_vals[i], zero_point_vals[i]), q_six_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } int_vec_return_type widening_subtract(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return {_mm256_sub_epi32(vals, b)}; -#else - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = int_vals[i] - b_vals[i]; - } - return {_mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals))}; -#endif } static Vectorized requantize_from_int( const int_vec_return_type& inp, float multiplier, int32_t zero_point) { -#ifdef CPU_CAPABILITY_AVX2 __m256 multiplier_v = _mm256_set1_ps(multiplier); __m256i zero_point_v = _mm256_set1_epi32(zero_point); __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v); __m256i rounded = _mm256_cvtps_epi32(scaled); return _mm256_add_epi32(rounded, zero_point_v); -#else - std::array inp_vals; - inp[0].store(inp_vals.data()); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = - nearbyint(static_cast(inp_vals[i]) * multiplier) + - zero_point; - } - return loadu(result_vals.data()); -#endif } void dump() const { @@ -411,43 +331,16 @@ template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_mullo_epi32(a, b); -#else - // Pray the compiler can autovectorize this - std::array::size()> a_vals; - std::array::size()> b_vals; - a.store(a_vals.data()); - b.store(b_vals.data()); - std::array::size()> result_vals; - for (size_t i = 0; i < std::decay_t::size(); ++i) { - result_vals[i] = a_vals[i] * b_vals[i]; - } - return Vectorized::loadu(result_vals.data()); -#endif } template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_add_epi32(a, b); -#else - // Pray the compiler can autovectorize this - std::array::size()> a_vals; - std::array::size()> b_vals; - a.store(a_vals.data()); - b.store(b_vals.data()); - std::array::size()> result_vals; - for (size_t i = 0; i < std::decay_t::size(); ++i) { - result_vals[i] = a_vals[i] + b_vals[i]; - } - return Vectorized::loadu(result_vals.data()); -#endif } -#ifdef CPU_CAPABILITY_AVX2 /* * Convert values from int32 back to int8/uint8 */ @@ -493,7 +386,6 @@ __m256i RequantizeAvx2( xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); return xyzw_clamped_v; } -#endif template<> struct Vectorized : public Vectorizedqi { @@ -544,21 +436,7 @@ struct Vectorized : public Vectorizedqi { private: __m256i cvtepi8_epi32(__m128i epi8_vals) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_cvtepi8_epi32(epi8_vals); -#else // CPU_CAPABILITY_AVX2 - __m128i result_data[2]; - __m128i unpacked1 = _mm_unpacklo_epi8(epi8_vals, epi8_vals); - __m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, unpacked1); - __m128i shifted1 = _mm_srli_si128(epi8_vals, 4); - __m128i shifted2 = _mm_srai_epi32(unpacked2, 24); - result_data[0] = shifted2; - __m128i unpacked3 = _mm_unpacklo_epi8(shifted1, shifted1); - __m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, unpacked3); - __m128i shifted3 = _mm_srai_epi32(unpacked4, 24); - result_data[1] = shifted3; - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data)); -#endif } public: @@ -576,7 +454,6 @@ struct Vectorized : public Vectorizedqi { __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); -#if defined(CPU_CAPABILITY_AVX2) auto val0 = vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); auto val1 = @@ -585,12 +462,6 @@ struct Vectorized : public Vectorizedqi { vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); auto val3 = vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); -#else - auto val0 = scale * (Vectorized(float_val0) - zero_point); - auto val1 = scale * (Vectorized(float_val1) - zero_point); - auto val2 = scale * (Vectorized(float_val2) - zero_point); - auto val3 = scale * (Vectorized(float_val3) - zero_point); -#endif return {val0, val1, val2, val3}; } @@ -607,39 +478,11 @@ struct Vectorized : public Vectorizedqi { } Vectorized maximum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_max_epi8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized minimum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized relu(Vectorized zero_point) const { @@ -649,29 +492,11 @@ struct Vectorized : public Vectorizedqi { Vectorized relu6( Vectorized zero_point, Vectorized q_six) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi8( _mm256_max_epi8(vals, zero_point.vals), q_six.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array zero_point_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); - std::array q_six_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min( - std::max(int_vals[i], zero_point_vals[i]), q_six_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } int_vec_return_type widening_subtract(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); @@ -701,55 +526,15 @@ struct Vectorized : public Vectorizedqi { Vectorized(res_1), Vectorized(res_2), Vectorized(res_3)}; -#else - // Pray the compiler can autovectorize this - std::array int_vals; - store(int_vals.data()); - std::array b_vals; - b.store(b_vals.data()); - constexpr int elem_per_int_vec = size() / int_num_vecs(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - rv[i][j] = static_cast(int_vals[i * elem_per_int_vec + j]) - - static_cast(b_vals[i * elem_per_int_vec + j]); - } - } - return {Vectorized::loadu(rv[0]), - Vectorized::loadu(rv[1]), - Vectorized::loadu(rv[2]), - Vectorized::loadu(rv[3])}; -#endif } static Vectorized requantize_from_int( const int_vec_return_type& inp, float multiplier, int32_t zero_point) { -#ifdef CPU_CAPABILITY_AVX2 __m256 multiplier_v = _mm256_set1_ps(multiplier); __m256i zero_point_v = _mm256_set1_epi32(zero_point); return RequantizeAvx2(inp, multiplier_v, zero_point_v); -#else - // Pray the compiler can autovectorize this - constexpr int elem_per_int_vec = size() / int_num_vecs(); - constexpr auto min_val = std::numeric_limits::min(); - constexpr auto max_val = std::numeric_limits::max(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - inp[i].store(rv[i]); - } - std::array result_vals; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - int32_t rounded = - nearbyint(static_cast(rv[i][j]) * multiplier) + zero_point; - result_vals[i * elem_per_int_vec + j] = - std::min(std::max(rounded, min_val), max_val); - } - } - return loadu(result_vals.data()); -#endif } void dump() const { @@ -817,20 +602,7 @@ struct Vectorized : public Vectorizedqi { private: __m256i cvtepu8_epi32(__m128i epu8_vals) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_cvtepu8_epi32(epu8_vals); -#else // CPU_CAPABILITY_AVX2 - __m128i result_data[2]; - __m128i zeros = _mm_setzero_si128(); - __m128i unpacked1 = _mm_unpacklo_epi8(epu8_vals, zeros); - __m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, zeros); - result_data[0] = unpacked2; - __m128i shifted = _mm_srli_si128(epu8_vals, 4); - __m128i unpacked3 = _mm_unpacklo_epi8(shifted, zeros); - __m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, zeros); - result_data[1] = unpacked4; - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data)); -#endif } public: @@ -848,7 +620,6 @@ struct Vectorized : public Vectorizedqi { __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); -#if defined(CPU_CAPABILITY_AVX2) auto val0 = vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); auto val1 = @@ -857,12 +628,6 @@ struct Vectorized : public Vectorizedqi { vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); auto val3 = vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); -#else - auto val0 = scale * (Vectorized(float_val0) - zero_point); - auto val1 = scale * (Vectorized(float_val1) - zero_point); - auto val2 = scale * (Vectorized(float_val2) - zero_point); - auto val3 = scale * (Vectorized(float_val3) - zero_point); -#endif return {val0, val1, val2, val3}; } @@ -879,39 +644,11 @@ struct Vectorized : public Vectorizedqi { } Vectorized maximum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_max_epu8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized minimum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epu8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized relu(Vectorized zero_point) const { @@ -921,29 +658,11 @@ struct Vectorized : public Vectorizedqi { Vectorized relu6( Vectorized zero_point, Vectorized q_six) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epu8( _mm256_max_epu8(vals, zero_point.vals), q_six.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array zero_point_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); - std::array q_six_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min( - std::max(int_vals[i], zero_point_vals[i]), q_six_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } int_vec_return_type widening_subtract(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); @@ -972,55 +691,15 @@ struct Vectorized : public Vectorizedqi { Vectorized(res_1), Vectorized(res_2), Vectorized(res_3)}; -#else - // Pray the compiler can autovectorize this - std::array int_vals; - std::array b_vals; - store(int_vals.data()); - b.store(b_vals.data()); - static constexpr int elem_per_int_vec = size() / int_num_vecs(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - rv[i][j] = static_cast(int_vals[i * elem_per_int_vec + j]) - - static_cast(b_vals[i * elem_per_int_vec + j]); - } - } - return {Vectorized::loadu(rv[0]), - Vectorized::loadu(rv[1]), - Vectorized::loadu(rv[2]), - Vectorized::loadu(rv[3])}; -#endif } static Vectorized requantize_from_int( const int_vec_return_type& inp, float multiplier, int32_t zero_point) { -#ifdef CPU_CAPABILITY_AVX2 __m256 multiplier_v = _mm256_set1_ps(multiplier); __m256i zero_point_v = _mm256_set1_epi32(zero_point); return RequantizeAvx2(inp, multiplier_v, zero_point_v); -#else - // Pray the compiler can autovectorize this - constexpr int elem_per_int_vec = size() / int_num_vecs(); - constexpr auto min_val = std::numeric_limits::min(); - constexpr auto max_val = std::numeric_limits::max(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - inp[i].store(rv[i]); - } - std::array result_vals; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - int32_t rounded = - nearbyint(static_cast(rv[i][j]) * multiplier) + zero_point; - result_vals[i * elem_per_int_vec + j] = - std::min(std::max(rounded, min_val), max_val); - } - } - return loadu(result_vals.data()); -#endif } void dump() const { @@ -1497,6 +1176,5 @@ Vectorized inline maximum(const Vectorized& a, const V return a.maximum(b); } -#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) - +#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h index 3c42b601645..3d798a7f626 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h index ce59bae3f4f..888f2f0b932 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -1,6 +1,6 @@ #pragma once -#include -#include +#include +#include #include #include @@ -141,7 +141,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -153,7 +153,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); std::memcpy( @@ -165,7 +165,7 @@ class Vectorized { ComplexDbl& operator[](int idx) = delete; Vectorized map(ComplexDbl (*const f)(ComplexDbl)) const { - __at_align32__ ComplexDbl tmp[size()]; + __at_align__ ComplexDbl tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -174,7 +174,7 @@ class Vectorized { } Vectorized map(ComplexDbl (*const f)(const ComplexDbl&)) const { - __at_align32__ ComplexDbl tmp[size()]; + __at_align__ ComplexDbl tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -455,8 +455,8 @@ class Vectorized { } Vectorized pow(const Vectorized& exp) const { - __at_align32__ ComplexDbl x_tmp[size()]; - __at_align32__ ComplexDbl y_tmp[size()]; + __at_align__ ComplexDbl x_tmp[size()]; + __at_align__ ComplexDbl y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h index f96488964bb..0aa726b9bfd 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include @@ -196,7 +196,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -209,7 +209,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); std::memcpy( @@ -221,7 +221,7 @@ class Vectorized { ComplexFlt& operator[](int idx) = delete; Vectorized map(ComplexFlt (*const f)(ComplexFlt)) const { - __at_align32__ ComplexFlt tmp[size()]; + __at_align__ ComplexFlt tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -230,7 +230,7 @@ class Vectorized { } Vectorized map(ComplexFlt (*const f)(const ComplexFlt&)) const { - __at_align32__ ComplexFlt tmp[size()]; + __at_align__ ComplexFlt tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -434,8 +434,8 @@ class Vectorized { } Vectorized pow(const Vectorized& exp) const { - __at_align32__ ComplexFlt x_tmp[size()]; - __at_align32__ ComplexFlt y_tmp[size()]; + __at_align__ ComplexFlt x_tmp[size()]; + __at_align__ ComplexFlt y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h index ac0a131878a..29616182fe1 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include @@ -169,7 +169,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -179,7 +179,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index 5fd1fb9afc8..2427276bcea 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include namespace at { @@ -180,7 +180,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -190,7 +190,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index 16a535fd1d1..bd179883c9b 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include namespace at { namespace vec { @@ -269,7 +269,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -279,7 +279,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index 759c4973965..460f49cbc8d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include namespace at { namespace vec { @@ -199,7 +199,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -209,7 +209,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index d2fbf4d51cf..fea09402965 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include namespace at { namespace vec { @@ -148,7 +148,7 @@ class Vectorized { (vint64)vec_vsx_ld(offset16, dptr)}; } - __at_align32__ double tmp_values[size()]; + __at_align__ double tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -161,7 +161,7 @@ class Vectorized { vec_vsx_st((vfloat64)_vec0, offset0, dptr); vec_vsx_st((vfloat64)_vec1, offset16, dptr); } else if (count > 0) { - __at_align32__ double tmp_values[size()]; + __at_align__ double tmp_values[size()]; vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h index ac2047d75ba..ed457b9adef 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include @@ -81,7 +81,7 @@ struct Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -91,7 +91,7 @@ struct Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h index 728ff51d71d..f2a8446cd0e 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include @@ -91,7 +91,7 @@ struct Vectorized { vec_vsx_ld(offset0, reinterpret_cast(ptr)), vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; } @@ -100,7 +100,7 @@ struct Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h index 4994abe9f13..c335ace0ced 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include @@ -92,7 +92,7 @@ struct Vectorized { vec_vsx_ld(offset0, reinterpret_cast(ptr)), vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; } @@ -101,7 +101,7 @@ struct Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h index 8bc943d62dc..afd21b09b45 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -1,7 +1,7 @@ #pragma once #include #include -#include +#include using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h new file mode 100644 index 00000000000..6f53067c0ef --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512.h @@ -0,0 +1,195 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + + +#include +#include +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [Acceptable use of anonymous namespace in header] +namespace { + + C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; + } + C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; + } + C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; + } + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << "]"; + return stream; +} + + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castpd_ps(src); +} + +template<> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castps_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm512_i64gather_pd(vindex, base_addr, scale); +} + +template +std::enable_if_t> +inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm512_i32gather_ps(vindex, base_addr, scale); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const double* base_addr, + const Vectorized& vindex, const Vectorized& mask) { + auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF)); + auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale); +} + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const float* base_addr, + const Vectorized& vindex, const Vectorized& mask) { + auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF)); + auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + return _mm512_cvtpd_epi64(src); +} + +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + return _mm512_cvttps_epi32(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); + return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} + // b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} + // + // return: + // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, + 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, + 27, 11, 26, 10, 25, 9, 24, 8); + return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + // The members of indices have been written in binary format for better understandability + __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} + // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} + __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h new file mode 100644 index 00000000000..4a240bb36d3 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -0,0 +1,879 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +static inline void cvtbf16_fp32(const __m256i& a, __m512& o) { + o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones); + auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm512_add_epi32(t_lo, vec_bias); + t_hi = _mm512_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm512_add_epi32(t_lo, lo); + t_hi = _mm512_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm512_srli_epi32(t_lo, 16); + t_hi = _mm512_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo); + t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi); + + t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, t_lo); +} + +static inline __m512i merge_compare_result(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + lo = _mm512_srli_epi32(lo, 16); + hi = _mm512_srli_epi32(hi, 16); + auto out = _mm512_packus_epi32(lo, hi); + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, out); +} + +template <> class Vectorized { +private: + __m512i values; +public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 32; + } + Vectorized() {} + Vectorized(__m512i v) : values(v) {} + Vectorized(BFloat16 val) { + value_type uw = val.x; + values = _mm512_set1_epi16(uw); + } + Vectorized(BFloat16 val1, BFloat16 val2, BFloat16 val3, BFloat16 val4, + BFloat16 val5, BFloat16 val6, BFloat16 val7, BFloat16 val8, + BFloat16 val9, BFloat16 val10, BFloat16 val11, BFloat16 val12, + BFloat16 val13, BFloat16 val14, BFloat16 val15, BFloat16 val16, + BFloat16 val17, BFloat16 val18, BFloat16 val19, BFloat16 val20, + BFloat16 val21, BFloat16 val22, BFloat16 val23, BFloat16 val24, + BFloat16 val25, BFloat16 val26, BFloat16 val27, BFloat16 val28, + BFloat16 val29, BFloat16 val30, BFloat16 val31, BFloat16 val32) { + values = _mm512_set_epi16( + val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x, + val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x, + val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x, + val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x); + } + operator __m512i() const { + return values; + } + BFloat16& operator[](int idx) = delete; + const BFloat16& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0)); + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = b.values[31]; + if (mask & 0x02) + tmp_values[1] = b.values[30]; + if (mask & 0x04) + tmp_values[2] = b.values[29]; + if (mask & 0x08) + tmp_values[3] = b.values[28]; + if (mask & 0x10) + tmp_values[4] = b.values[27]; + if (mask & 0x20) + tmp_values[5] = b.values[26]; + if (mask & 0x40) + tmp_values[6] = b.values[25]; + if (mask & 0x80) + tmp_values[7] = b.values[24]; + if (mask & 0x100) + tmp_values[8] = b.values[23]; + if (mask & 0x200) + tmp_values[9] = b.values[22]; + if (mask & 0x400) + tmp_values[10] = b.values[21]; + if (mask & 0x800) + tmp_values[11] = b.values[20]; + if (mask & 0x1000) + tmp_values[12] = b.values[19]; + if (mask & 0x2000) + tmp_values[13] = b.values[18]; + if (mask & 0x4000) + tmp_values[14] = b.values[17]; + if (mask & 0x8000) + tmp_values[15] = b.values[16]; + if (mask & 0x10000) + tmp_values[16] = b.values[15]; + if (mask & 0x20000) + tmp_values[17] = b.values[14]; + if (mask & 0x40000) + tmp_values[18] = b.values[13]; + if (mask & 0x80000) + tmp_values[19] = b.values[12]; + if (mask & 0x100000) + tmp_values[20] = b.values[11]; + if (mask & 0x200000) + tmp_values[21] = b.values[10]; + if (mask & 0x400000) + tmp_values[22] = b.values[9]; + if (mask & 0x800000) + tmp_values[23] = b.values[8]; + if (mask & 0x1000000) + tmp_values[24] = b.values[7]; + if (mask & 0x2000000) + tmp_values[25] = b.values[6]; + if (mask & 0x4000000) + tmp_values[26] = b.values[5]; + if (mask & 0x8000000) + tmp_values[27] = b.values[4]; + if (mask & 0x10000000) + tmp_values[28] = b.values[3]; + if (mask & 0x20000000) + tmp_values[29] = b.values[2]; + if (mask & 0x40000000) + tmp_values[30] = b.values[1]; + if (mask & 0x80000000) + tmp_values[31] = b.values[0]; + return loadu(tmp_values); + } + static Vectorized blendv(const Vectorized& a, + const Vectorized& b, const Vectorized& mask) { + auto all_ones = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange(BFloat16 base = 0.f, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, + base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step, + base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, + base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, + base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step); + } + static Vectorized set(const Vectorized& a, + const Vectorized& b, int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + case 16: + return blend<65535>(a, b); + case 17: + return blend<131071>(a, b); + case 18: + return blend<262143>(a, b); + case 19: + return blend<524287>(a, b); + case 20: + return blend<1048575>(a, b); + case 21: + return blend<2097151>(a, b); + case 22: + return blend<4194303>(a, b); + case 23: + return blend<8388607>(a, b); + case 24: + return blend<16777215>(a, b); + case 25: + return blend<33554431>(a, b); + case 26: + return blend<67108863>(a, b); + case 27: + return blend<134217727>(a, b); + case 28: + return blend<268435455>(a, b); + case 29: + return blend<536870911>(a, b); + case 30: + return blend<1073741823>(a, b); + case 31: + return blend<2147483647>(a, b); + } + return b; + } + Vectorized map(const __m512 (*const vop)(__m512)) const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized abs() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + const auto mask = _mm512_set1_ps(-0.f); + const auto o1 = _mm512_andnot_ps(mask, lo); + const auto o2 = _mm512_andnot_ps(mask, hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized angle() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto angle_lambda = [](__m512 values) { + const auto zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec), + not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec), + zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf16_u10); + } + Vectorized asin() const { + return map(Sleef_asinf16_u10); + } + Vectorized atan() const { + return map(Sleef_atanf16_u10); + } + Vectorized atan2(const Vectorized &b) const { + __m512 lo, hi; + __m512 b1, b2; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f16_u10(lo, b1); + auto o2 = Sleef_atan2f16_u10(hi, b2); + return cvtfp32_bf16(o1, o2); + } + Vectorized copysign(const Vectorized &sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m512i mask_value = _mm512_set1_epi32(~0x80008000); + __m512i mask_signbit = _mm512_set1_epi32(0x80008000); + return Vectorized( + _mm512_or_si512( + _mm512_and_si512(values, mask_value), + _mm512_and_si512(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff16_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf16_u15); + } + Vectorized erfinv() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf16_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f16_u10); + } + Vectorized fmod(const Vectorized & q) const { + __m512 x_lo, x_hi; + cvtbf16_fp32(values, x_lo, x_hi); + __m512 q_lo, q_hi; + cvtbf16_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf16(x_lo, q_lo); + auto o2 = Sleef_fmodf16(x_hi, q_hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized hypot(const Vectorized &b) const { + __m512 lo, hi; + __m512 b1, b2; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf16_u05(lo, b1); + auto o2 = Sleef_hypotf16_u05(hi, b2); + return cvtfp32_bf16(o1, o2); + } + Vectorized i0() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized i0e() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized igamma(const Vectorized &x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + + Vectorized igammac(const Vectorized &x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf16_u10); + } + Vectorized log2() const { + return map(Sleef_log2f16_u10); + } + Vectorized log10() const { + return map(Sleef_log10f16_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf16_u10); + } + Vectorized frac() const; + Vectorized sin() const { + return map(Sleef_sinf16_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf16_u10); + } + Vectorized cos() const { + return map(Sleef_cosf16_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf16_u10); + } + Vectorized ceil() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_ceil_ps(lo); + auto o2 = _mm512_ceil_ps(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized floor() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_floor_ps(lo); + auto o2 = _mm512_floor_ps(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized neg() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto mask = _mm512_set1_ps(-0.f); + auto o1 = _mm512_xor_ps(mask, lo); + auto o2 = _mm512_xor_ps(mask, hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized round() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvtfp32_bf16(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf16_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf16_u10); + } + Vectorized trunc() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvtfp32_bf16(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf16_u10); + } + Vectorized sqrt() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_sqrt_ps(lo); + auto o2 = _mm512_sqrt_ps(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized reciprocal() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, lo); + auto o2 = _mm512_div_ps(ones, hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized rsqrt() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo)); + auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi)); + return cvtfp32_bf16(o1, o2); + } + Vectorized pow(const Vectorized &b) const { + __m512 lo, hi; + __m512 b1, b2; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(b.values, b1, b2); + auto o1 = Sleef_powf16_u10(lo, b1); + auto o2 = Sleef_powf16_u10(hi, b2); + return cvtfp32_bf16(o1, o2); + } + + Vectorized inline operator>(const Vectorized& other) const; + Vectorized inline operator<(const Vectorized& other) const; + Vectorized inline operator>=(const Vectorized& other) const; + Vectorized inline operator<=(const Vectorized& other) const; + Vectorized inline operator==(const Vectorized& other) const; + Vectorized inline operator!=(const Vectorized& other) const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +Vectorized static inline bfloat16_binary_op_as_fp32(const Vectorized& a, + const Vectorized& b, Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvtfp32_bf16(o1, o2); +} + +template +Vectorized static inline bfloat16_compare_as_fp32(const Vectorized& a, + const Vectorized& b, Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return merge_compare_result(o1, o2); +} + +Vectorized inline Vectorized::operator>(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator<(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator>=(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator<=(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator==(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator!=(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} + +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); }); +} +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); }); +} +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); }); +} +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); }); +} + +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask, + 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask, + 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp(const Vectorized& a, + const Vectorized& min, const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#pragma unroll + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#pragma unroll + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, + const Vectorized& b, const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + cvtbf16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { + __m512 o1, o2; + cvtbf16_fp32(__m512i(a), o1, o2); + return std::make_tuple(o1, o2); +} + +inline Vectorized convert_float_bfloat16(const Vectorized& a, const Vectorized& b) { + return cvtfp32_bf16(__m512(a), __m512(b)); +} + +#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_bfloat16(const Vectorized& a, const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { + auto values = _mm256_loadu_si256(reinterpret_cast(data)); + __m512 out_values; + cvtbf16_fp32(values, out_values); + out = out_values; +} + +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out1, Vectorized& out2) { + auto vec = Vectorized::loadu(data); + __m512 out1_values, out2_values; + cvtbf16_fp32(vec, out1_values, out2_values); + out1 = out1_values; + out2 = out2_values; +} +#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (int k = 0; k < Vectorized::size(); ++k) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out1, Vectorized& out2) { + load_fp32_from_bf16(data, out1); + data += Vectorized::size(); + load_fp32_from_bf16(data, out2); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h new file mode 100644 index 00000000000..6fc22f0f7d3 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -0,0 +1,526 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized> { +private: + __m512d values; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value, + real_value, imag_value, real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2, + c10::complex val3, c10::complex val4) { + values = _mm512_setr_pd(val1.real(), val1.imag(), + val2.real(), val2.imag(), + val3.real(), val3.imag(), + val4.real(), val4.imag()); + } + operator __m512d() const { + return values; + } + template + static Vectorized> blend(const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011 + case 2: + return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100 + case 3: + return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111 + case 4: + return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000 + case 5: + return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011 + case 6: + return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100 + case 7: + return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111 + case 8: + return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000 + case 9: + return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011 + case 10: + return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100 + case 11: + return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111 + case 12: + return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000 + case 13: + return _mm512_mask_blend_pd(0xF3, a.values, b.values); //b0000 1101 = b1111 0011 + case 14: + return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100 + case 15: + return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111 + } + return b; + } + static Vectorized> blendv(const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values); + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized> arange(c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, + base + c10::complex(1)*step, + base + c10::complex(2)*step, + base + c10::complex(3)*step); + } + static Vectorized> set(const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2*size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < 2*size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2*size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512d hadd_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + static inline __m512d hsub_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + __m512d abs_2_() const { + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m512d abs_() const { + return _mm512_sqrt_pd(abs_2_()); // abs abs + } + Vectorized> abs() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + return _mm512_and_pd(abs_(), real_mask); // abs 0 + } + __m512d angle_() const { + //angle = atan2(b/a) + auto b_a = _mm512_permute_pd(values, 0x55); // b a + return Sleef_atan2d8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle + return _mm512_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_pd(); + auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ); + auto mask_vec = _mm512_mask_set1_epi64(_mm512_castpd_si512(zero), mask, + 0xFFFFFFFFFFFFFFFF); + auto abs_val = Vectorized(abs); + + auto div = values / abs_val.values; // x / abs(x) + + return blendv(div, zero, _mm512_castsi512_pd(mask_vec)); + } + __m512d real_() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + return _mm512_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512d imag_() const { + const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF)); + return _mm512_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_pd(imag_(), 0x55); //b a + } + __m512d conj_() const { + const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512d log2_ = _mm512_set1_pd(std::log(2)); + return _mm512_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m512d log10_ = _mm512_set1_pd(std::log(10)); + return _mm512_div_pd(log(), log10_); + } + Vectorized> log1p() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + const __m512d one = _mm512_set1_pd(1); + + auto conj = conj_(); + auto b_a = _mm512_permute_pd(conj, 0x55); //-b a + auto ab = _mm512_mul_pd(conj, b_a); //-ab -ab + auto im = _mm512_add_pd(ab, ab); //-2ab -2ab + + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b b*b-a*a + re = _mm512_sub_pd(one, re); + + auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); //sqrt(re + i*im) + auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); //ln(iz + sqrt()) + return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj(); //-i*ln() + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0); + return _mm512_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atan2(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erf() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erfc() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> exp() const { + //exp(a + bi) + // = exp(a)*(cos(b) + sin(b)i) + auto exp = Sleef_expd8_u10(values); //exp(a) exp(b) + exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) exp(a) + + auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] + auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55), + sin_cos.x); //cos(b) sin(b) + return _mm512_mul_pd(exp, cos_sin); + } + Vectorized> expm1() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized> floor() const { + return _mm512_floor_pd(values); + } + Vectorized> hypot(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igamma(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igammac(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_pd(); + return _mm512_sub_pd(zero, values); + } + Vectorized> nextafter(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> round() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow(const Vectorized> &exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==(const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask, + 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator!=(const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask, + 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator<(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq(const Vectorized>& other) const; + Vectorized> ne(const Vectorized>& other) const; + Vectorized> lt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> le(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> gt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> ge(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +template <> Vectorized> inline operator+(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_add_pd(a, b); +} + +template <> Vectorized> inline operator-(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_sub_pd(a, b); +} + +template <> Vectorized> inline operator*(const Vectorized> &a, + const Vectorized> &b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm512_mul_pd(a, b); //ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); //d c + d_c = _mm512_xor_pd(sign_mask, d_c); //d -c + auto ad_bc = _mm512_mul_pd(a, d_c); //ad -bc + + auto ret = Vectorized>::hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc + return ret; +} + +template <> Vectorized> inline operator/(const Vectorized> &a, + const Vectorized> &b) { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() + //im = (bc - ad)/abs_2() + const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); + auto ac_bd = _mm512_mul_pd(a, b); //ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); //d c + d_c = _mm512_xor_pd(sign_mask, d_c); //-d c + auto ad_bc = _mm512_mul_pd(a, d_c); //-ad bc + + auto re_im = Vectorized>::hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad + return _mm512_div_pd(re_im, b.abs_2_()); +} + +// reciprocal. Implement this here so we can use multiplication. +Vectorized> Vectorized>::reciprocal() const{ + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() = c/abs_2() + //im = (bc - ad)/abs_2() = d/abs_2() + const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto c_d = _mm512_xor_pd(sign_mask, values); //c -d + return _mm512_div_pd(c_d, abs_2_()); +} + +Vectorized> Vectorized>::atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); + + auto sum = Vectorized(_mm512_add_pd(i, values)); // a 1+b + auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b + auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) + return i_half*ln; // i/2*ln() +} + +template <> +Vectorized> inline maximum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(max, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline minimum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(min, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline operator&(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized> inline operator|(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized> inline operator^(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_pd(a, b); +} + +Vectorized> Vectorized>::eq(const Vectorized>& other) const { + return (*this == other) & Vectorized>(_mm512_set1_pd(1.0)); +} + +Vectorized> Vectorized>::ne(const Vectorized>& other) const { + return (*this != other) & Vectorized>(_mm512_set1_pd(1.0)); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h new file mode 100644 index 00000000000..dfd070604c4 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -0,0 +1,1030 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized> { +private: + __m512 values; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm512_setr_ps(real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2, + c10::complex val3, c10::complex val4, + c10::complex val5, c10::complex val6, + c10::complex val7, c10::complex val8) { + values = _mm512_setr_ps(val1.real(), val1.imag(), + val2.real(), val2.imag(), + val3.real(), val3.imag(), + val4.real(), val4.imag(), + val5.real(), val5.imag(), + val6.real(), val6.imag(), + val7.real(), val7.imag(), + val8.real(), val8.imag()); + } + operator __m512() const { + return values; + } + template + static Vectorized> blend(const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + // The compiler would hopefully convert this switch condition + // into a jump table + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_ps(0x03, a.values, b.values); + case 2: + return _mm512_mask_blend_ps(0x0C, a.values, b.values); + case 3: + return _mm512_mask_blend_ps(0x0F, a.values, b.values); + case 4: + return _mm512_mask_blend_ps(0x30, a.values, b.values); + case 5: + return _mm512_mask_blend_ps(0x33, a.values, b.values); + case 6: + return _mm512_mask_blend_ps(0x3C, a.values, b.values); + case 7: + return _mm512_mask_blend_ps(0x3F, a.values, b.values); + case 8: + return _mm512_mask_blend_ps(0xC0, a.values, b.values); + case 9: + return _mm512_mask_blend_ps(0xC3, a.values, b.values); + case 10: + return _mm512_mask_blend_ps(0xCC, a.values, b.values); + case 11: + return _mm512_mask_blend_ps(0xCF, a.values, b.values); + case 12: + return _mm512_mask_blend_ps(0xF0, a.values, b.values); + case 13: + return _mm512_mask_blend_ps(0xF3, a.values, b.values); + case 14: + return _mm512_mask_blend_ps(0xFC, a.values, b.values); + case 15: + return _mm512_mask_blend_ps(0xFF, a.values, b.values); + case 16: + return _mm512_mask_blend_ps(0x300, a.values, b.values); + case 17: + return _mm512_mask_blend_ps(0x303, a.values, b.values); + case 18: + return _mm512_mask_blend_ps(0x30C, a.values, b.values); + case 19: + return _mm512_mask_blend_ps(0x30F, a.values, b.values); + case 20: + return _mm512_mask_blend_ps(0x330, a.values, b.values); + case 21: + return _mm512_mask_blend_ps(0x333, a.values, b.values); + case 22: + return _mm512_mask_blend_ps(0x33C, a.values, b.values); + case 23: + return _mm512_mask_blend_ps(0x33F, a.values, b.values); + case 24: + return _mm512_mask_blend_ps(0x3C0, a.values, b.values); + case 25: + return _mm512_mask_blend_ps(0x3C3, a.values, b.values); + case 26: + return _mm512_mask_blend_ps(0x3CC, a.values, b.values); + case 27: + return _mm512_mask_blend_ps(0x3CF, a.values, b.values); + case 28: + return _mm512_mask_blend_ps(0x3F0, a.values, b.values); + case 29: + return _mm512_mask_blend_ps(0x3F3, a.values, b.values); + case 30: + return _mm512_mask_blend_ps(0x3FC, a.values, b.values); + case 31: + return _mm512_mask_blend_ps(0x3FF, a.values, b.values); + case 32: + return _mm512_mask_blend_ps(0xC00, a.values, b.values); + case 33: + return _mm512_mask_blend_ps(0xC03, a.values, b.values); + case 34: + return _mm512_mask_blend_ps(0xC0C, a.values, b.values); + case 35: + return _mm512_mask_blend_ps(0xC0F, a.values, b.values); + case 36: + return _mm512_mask_blend_ps(0xC30, a.values, b.values); + case 37: + return _mm512_mask_blend_ps(0xC33, a.values, b.values); + case 38: + return _mm512_mask_blend_ps(0xC3C, a.values, b.values); + case 39: + return _mm512_mask_blend_ps(0xC3F, a.values, b.values); + case 40: + return _mm512_mask_blend_ps(0xCC0, a.values, b.values); + case 41: + return _mm512_mask_blend_ps(0xCC3, a.values, b.values); + case 42: + return _mm512_mask_blend_ps(0xCCC, a.values, b.values); + case 43: + return _mm512_mask_blend_ps(0xCCF, a.values, b.values); + case 44: + return _mm512_mask_blend_ps(0xCF0, a.values, b.values); + case 45: + return _mm512_mask_blend_ps(0xCF3, a.values, b.values); + case 46: + return _mm512_mask_blend_ps(0xCFC, a.values, b.values); + case 47: + return _mm512_mask_blend_ps(0xCFF, a.values, b.values); + case 48: + return _mm512_mask_blend_ps(0xF00, a.values, b.values); + case 49: + return _mm512_mask_blend_ps(0xF03, a.values, b.values); + case 50: + return _mm512_mask_blend_ps(0xF0C, a.values, b.values); + case 51: + return _mm512_mask_blend_ps(0xF0F, a.values, b.values); + case 52: + return _mm512_mask_blend_ps(0xF30, a.values, b.values); + case 53: + return _mm512_mask_blend_ps(0xF33, a.values, b.values); + case 54: + return _mm512_mask_blend_ps(0xF3C, a.values, b.values); + case 55: + return _mm512_mask_blend_ps(0xF3F, a.values, b.values); + case 56: + return _mm512_mask_blend_ps(0xFC0, a.values, b.values); + case 57: + return _mm512_mask_blend_ps(0xFC3, a.values, b.values); + case 58: + return _mm512_mask_blend_ps(0xFCC, a.values, b.values); + case 59: + return _mm512_mask_blend_ps(0xFCF, a.values, b.values); + case 60: + return _mm512_mask_blend_ps(0xFF0, a.values, b.values); + case 61: + return _mm512_mask_blend_ps(0xFF3, a.values, b.values); + case 62: + return _mm512_mask_blend_ps(0xFFC, a.values, b.values); + case 63: + return _mm512_mask_blend_ps(0xFFF, a.values, b.values); + case 64: + return _mm512_mask_blend_ps(0x3000, a.values, b.values); + case 65: + return _mm512_mask_blend_ps(0x3003, a.values, b.values); + case 66: + return _mm512_mask_blend_ps(0x300C, a.values, b.values); + case 67: + return _mm512_mask_blend_ps(0x300F, a.values, b.values); + case 68: + return _mm512_mask_blend_ps(0x3030, a.values, b.values); + case 69: + return _mm512_mask_blend_ps(0x3033, a.values, b.values); + case 70: + return _mm512_mask_blend_ps(0x303C, a.values, b.values); + case 71: + return _mm512_mask_blend_ps(0x303F, a.values, b.values); + case 72: + return _mm512_mask_blend_ps(0x30C0, a.values, b.values); + case 73: + return _mm512_mask_blend_ps(0X30C3, a.values, b.values); + case 74: + return _mm512_mask_blend_ps(0x30CC, a.values, b.values); + case 75: + return _mm512_mask_blend_ps(0x30CF, a.values, b.values); + case 76: + return _mm512_mask_blend_ps(0x30F0, a.values, b.values); + case 77: + return _mm512_mask_blend_ps(0x30F3, a.values, b.values); + case 78: + return _mm512_mask_blend_ps(0x30FC, a.values, b.values); + case 79: + return _mm512_mask_blend_ps(0x30FF, a.values, b.values); + case 80: + return _mm512_mask_blend_ps(0x3300, a.values, b.values); + case 81: + return _mm512_mask_blend_ps(0X3303, a.values, b.values); + case 82: + return _mm512_mask_blend_ps(0x330C, a.values, b.values); + case 83: + return _mm512_mask_blend_ps(0x330F, a.values, b.values); + case 84: + return _mm512_mask_blend_ps(0x3330, a.values, b.values); + case 85: + return _mm512_mask_blend_ps(0x3333, a.values, b.values); + case 86: + return _mm512_mask_blend_ps(0x333C, a.values, b.values); + case 87: + return _mm512_mask_blend_ps(0X333F, a.values, b.values); + case 88: + return _mm512_mask_blend_ps(0x33C0, a.values, b.values); + case 89: + return _mm512_mask_blend_ps(0x33C3, a.values, b.values); + case 90: + return _mm512_mask_blend_ps(0x33CC, a.values, b.values); + case 91: + return _mm512_mask_blend_ps(0x33CF, a.values, b.values); + case 92: + return _mm512_mask_blend_ps(0x33F0, a.values, b.values); + case 93: + return _mm512_mask_blend_ps(0x33F3, a.values, b.values); + case 94: + return _mm512_mask_blend_ps(0x33FC, a.values, b.values); + case 95: + return _mm512_mask_blend_ps(0x33FF, a.values, b.values); + case 96: + return _mm512_mask_blend_ps(0X3C00, a.values, b.values); + case 97: + return _mm512_mask_blend_ps(0x3C03, a.values, b.values); + case 98: + return _mm512_mask_blend_ps(0x3C0C, a.values, b.values); + case 99: + return _mm512_mask_blend_ps(0x3C0F, a.values, b.values); + case 100: + return _mm512_mask_blend_ps(0x3C30, a.values, b.values); + case 101: + return _mm512_mask_blend_ps(0x3C33, a.values, b.values); + case 102: + return _mm512_mask_blend_ps(0x3C3C, a.values, b.values); + case 103: + return _mm512_mask_blend_ps(0x3C3F, a.values, b.values); + case 104: + return _mm512_mask_blend_ps(0x3CC0, a.values, b.values); + case 105: + return _mm512_mask_blend_ps(0x3CC3, a.values, b.values); + case 106: + return _mm512_mask_blend_ps(0x3CCC, a.values, b.values); + case 107: + return _mm512_mask_blend_ps(0x3CCF, a.values, b.values); + case 108: + return _mm512_mask_blend_ps(0x3CF0, a.values, b.values); + case 109: + return _mm512_mask_blend_ps(0x3CF3, a.values, b.values); + case 110: + return _mm512_mask_blend_ps(0x3CFC, a.values, b.values); + case 111: + return _mm512_mask_blend_ps(0x3CFF, a.values, b.values); + case 112: + return _mm512_mask_blend_ps(0x3F00, a.values, b.values); + case 113: + return _mm512_mask_blend_ps(0x3F03, a.values, b.values); + case 114: + return _mm512_mask_blend_ps(0x3F0C, a.values, b.values); + case 115: + return _mm512_mask_blend_ps(0x3F0F, a.values, b.values); + case 116: + return _mm512_mask_blend_ps(0x3F30, a.values, b.values); + case 117: + return _mm512_mask_blend_ps(0x3F33, a.values, b.values); + case 118: + return _mm512_mask_blend_ps(0x3F3C, a.values, b.values); + case 119: + return _mm512_mask_blend_ps(0x3F3F, a.values, b.values); + case 120: + return _mm512_mask_blend_ps(0x3FC0, a.values, b.values); + case 121: + return _mm512_mask_blend_ps(0x3FC3, a.values, b.values); + case 122: + return _mm512_mask_blend_ps(0x3FCC, a.values, b.values); + case 123: + return _mm512_mask_blend_ps(0x3FCF, a.values, b.values); + case 124: + return _mm512_mask_blend_ps(0x3FF0, a.values, b.values); + case 125: + return _mm512_mask_blend_ps(0x3FF3, a.values, b.values); + case 126: + return _mm512_mask_blend_ps(0x3FFC, a.values, b.values); + case 127: + return _mm512_mask_blend_ps(0x3FFF, a.values, b.values); + case 128: + return _mm512_mask_blend_ps(0xC000, a.values, b.values); + case 129: + return _mm512_mask_blend_ps(0xC003, a.values, b.values); + case 130: + return _mm512_mask_blend_ps(0xC00C, a.values, b.values); + case 131: + return _mm512_mask_blend_ps(0xC00F, a.values, b.values); + case 132: + return _mm512_mask_blend_ps(0xC030, a.values, b.values); + case 133: + return _mm512_mask_blend_ps(0xC033, a.values, b.values); + case 134: + return _mm512_mask_blend_ps(0xC03C, a.values, b.values); + case 135: + return _mm512_mask_blend_ps(0xC03F, a.values, b.values); + case 136: + return _mm512_mask_blend_ps(0xC0C0, a.values, b.values); + case 137: + return _mm512_mask_blend_ps(0xC0C3, a.values, b.values); + case 138: + return _mm512_mask_blend_ps(0xC0CC, a.values, b.values); + case 139: + return _mm512_mask_blend_ps(0xC0CF, a.values, b.values); + case 140: + return _mm512_mask_blend_ps(0xC0F0, a.values, b.values); + case 141: + return _mm512_mask_blend_ps(0xC0F3, a.values, b.values); + case 142: + return _mm512_mask_blend_ps(0xC0FC, a.values, b.values); + case 143: + return _mm512_mask_blend_ps(0xC0FF, a.values, b.values); + case 144: + return _mm512_mask_blend_ps(0xC300, a.values, b.values); + case 145: + return _mm512_mask_blend_ps(0xC303, a.values, b.values); + case 146: + return _mm512_mask_blend_ps(0xC30C, a.values, b.values); + case 147: + return _mm512_mask_blend_ps(0xC30F, a.values, b.values); + case 148: + return _mm512_mask_blend_ps(0xC330, a.values, b.values); + case 149: + return _mm512_mask_blend_ps(0xC333, a.values, b.values); + case 150: + return _mm512_mask_blend_ps(0xC33C, a.values, b.values); + case 151: + return _mm512_mask_blend_ps(0xC33F, a.values, b.values); + case 152: + return _mm512_mask_blend_ps(0xC3C0, a.values, b.values); + case 153: + return _mm512_mask_blend_ps(0xC3C3, a.values, b.values); + case 154: + return _mm512_mask_blend_ps(0xC3CC, a.values, b.values); + case 155: + return _mm512_mask_blend_ps(0xC3CF, a.values, b.values); + case 156: + return _mm512_mask_blend_ps(0xC3F0, a.values, b.values); + case 157: + return _mm512_mask_blend_ps(0xC3F3, a.values, b.values); + case 158: + return _mm512_mask_blend_ps(0xC3FC, a.values, b.values); + case 159: + return _mm512_mask_blend_ps(0xC3FF, a.values, b.values); + case 160: + return _mm512_mask_blend_ps(0xCC00, a.values, b.values); + case 161: + return _mm512_mask_blend_ps(0xCC03, a.values, b.values); + case 162: + return _mm512_mask_blend_ps(0xCC0C, a.values, b.values); + case 163: + return _mm512_mask_blend_ps(0xCC0F, a.values, b.values); + case 164: + return _mm512_mask_blend_ps(0xCC30, a.values, b.values); + case 165: + return _mm512_mask_blend_ps(0xCC33, a.values, b.values); + case 166: + return _mm512_mask_blend_ps(0xCC3C, a.values, b.values); + case 167: + return _mm512_mask_blend_ps(0xCC3F, a.values, b.values); + case 168: + return _mm512_mask_blend_ps(0xCCC0, a.values, b.values); + case 169: + return _mm512_mask_blend_ps(0xCCC3, a.values, b.values); + case 170: + return _mm512_mask_blend_ps(0xCCCC, a.values, b.values); + case 171: + return _mm512_mask_blend_ps(0xCCCF, a.values, b.values); + case 172: + return _mm512_mask_blend_ps(0xCCF0, a.values, b.values); + case 173: + return _mm512_mask_blend_ps(0xCCF3, a.values, b.values); + case 174: + return _mm512_mask_blend_ps(0xCCFC, a.values, b.values); + case 175: + return _mm512_mask_blend_ps(0xCCFF, a.values, b.values); + case 176: + return _mm512_mask_blend_ps(0xCF00, a.values, b.values); + case 177: + return _mm512_mask_blend_ps(0xCF03, a.values, b.values); + case 178: + return _mm512_mask_blend_ps(0xCF0C, a.values, b.values); + case 179: + return _mm512_mask_blend_ps(0xCF0F, a.values, b.values); + case 180: + return _mm512_mask_blend_ps(0xCF30, a.values, b.values); + case 181: + return _mm512_mask_blend_ps(0xCF33, a.values, b.values); + case 182: + return _mm512_mask_blend_ps(0xCF3C, a.values, b.values); + case 183: + return _mm512_mask_blend_ps(0xCF3F, a.values, b.values); + case 184: + return _mm512_mask_blend_ps(0xCFC0, a.values, b.values); + case 185: + return _mm512_mask_blend_ps(0xCFC3, a.values, b.values); + case 186: + return _mm512_mask_blend_ps(0xCFCC, a.values, b.values); + case 187: + return _mm512_mask_blend_ps(0xCFCF, a.values, b.values); + case 188: + return _mm512_mask_blend_ps(0xCFF0, a.values, b.values); + case 189: + return _mm512_mask_blend_ps(0xCFF3, a.values, b.values); + case 190: + return _mm512_mask_blend_ps(0xCFFC, a.values, b.values); + case 191: + return _mm512_mask_blend_ps(0xCFFF, a.values, b.values); + case 192: + return _mm512_mask_blend_ps(0xF000, a.values, b.values); + case 193: + return _mm512_mask_blend_ps(0xF003, a.values, b.values); + case 194: + return _mm512_mask_blend_ps(0xF00C, a.values, b.values); + case 195: + return _mm512_mask_blend_ps(0xF00F, a.values, b.values); + case 196: + return _mm512_mask_blend_ps(0xF030, a.values, b.values); + case 197: + return _mm512_mask_blend_ps(0xF033, a.values, b.values); + case 198: + return _mm512_mask_blend_ps(0xF03C, a.values, b.values); + case 199: + return _mm512_mask_blend_ps(0xF03F, a.values, b.values); + case 200: + return _mm512_mask_blend_ps(0XF0C0, a.values, b.values); + case 201: + return _mm512_mask_blend_ps(0xF0C3, a.values, b.values); + case 202: + return _mm512_mask_blend_ps(0xF0CC, a.values, b.values); + case 203: + return _mm512_mask_blend_ps(0xF0CF, a.values, b.values); + case 204: + return _mm512_mask_blend_ps(0xF0F0, a.values, b.values); + case 205: + return _mm512_mask_blend_ps(0xF0F3, a.values, b.values); + case 206: + return _mm512_mask_blend_ps(0xF0FC, a.values, b.values); + case 207: + return _mm512_mask_blend_ps(0xF0FF, a.values, b.values); + case 208: + return _mm512_mask_blend_ps(0XF300, a.values, b.values); + case 209: + return _mm512_mask_blend_ps(0xF303, a.values, b.values); + case 210: + return _mm512_mask_blend_ps(0xF30C, a.values, b.values); + case 211: + return _mm512_mask_blend_ps(0xF30F, a.values, b.values); + case 212: + return _mm512_mask_blend_ps(0xF330, a.values, b.values); + case 213: + return _mm512_mask_blend_ps(0xF333, a.values, b.values); + case 214: + return _mm512_mask_blend_ps(0XF33C, a.values, b.values); + case 215: + return _mm512_mask_blend_ps(0xF33F, a.values, b.values); + case 216: + return _mm512_mask_blend_ps(0xF3C0, a.values, b.values); + case 217: + return _mm512_mask_blend_ps(0xF3C3, a.values, b.values); + case 218: + return _mm512_mask_blend_ps(0xF3CC, a.values, b.values); + case 219: + return _mm512_mask_blend_ps(0xF3CF, a.values, b.values); + case 220: + return _mm512_mask_blend_ps(0xF3F0, a.values, b.values); + case 221: + return _mm512_mask_blend_ps(0xF3F3, a.values, b.values); + case 222: + return _mm512_mask_blend_ps(0xF3FC, a.values, b.values); + case 223: + return _mm512_mask_blend_ps(0XF3FF, a.values, b.values); + case 224: + return _mm512_mask_blend_ps(0xFC00, a.values, b.values); + case 225: + return _mm512_mask_blend_ps(0xFC03, a.values, b.values); + case 226: + return _mm512_mask_blend_ps(0xFC0C, a.values, b.values); + case 227: + return _mm512_mask_blend_ps(0xFC0F, a.values, b.values); + case 228: + return _mm512_mask_blend_ps(0xFC30, a.values, b.values); + case 229: + return _mm512_mask_blend_ps(0xFC33, a.values, b.values); + case 230: + return _mm512_mask_blend_ps(0xFC3C, a.values, b.values); + case 231: + return _mm512_mask_blend_ps(0xFC3F, a.values, b.values); + case 232: + return _mm512_mask_blend_ps(0xFCC0, a.values, b.values); + case 233: + return _mm512_mask_blend_ps(0xFCC3, a.values, b.values); + case 234: + return _mm512_mask_blend_ps(0xFCCC, a.values, b.values); + case 235: + return _mm512_mask_blend_ps(0xFCCF, a.values, b.values); + case 236: + return _mm512_mask_blend_ps(0xFCF0, a.values, b.values); + case 237: + return _mm512_mask_blend_ps(0xFCF3, a.values, b.values); + case 238: + return _mm512_mask_blend_ps(0xFCFC, a.values, b.values); + case 239: + return _mm512_mask_blend_ps(0xFCFF, a.values, b.values); + case 240: + return _mm512_mask_blend_ps(0xFF00, a.values, b.values); + case 241: + return _mm512_mask_blend_ps(0xFF03, a.values, b.values); + case 242: + return _mm512_mask_blend_ps(0xFF0C, a.values, b.values); + case 243: + return _mm512_mask_blend_ps(0xFF0F, a.values, b.values); + case 244: + return _mm512_mask_blend_ps(0xFF30, a.values, b.values); + case 245: + return _mm512_mask_blend_ps(0xFF33, a.values, b.values); + case 246: + return _mm512_mask_blend_ps(0xFF3C, a.values, b.values); + case 247: + return _mm512_mask_blend_ps(0xFF3F, a.values, b.values); + case 248: + return _mm512_mask_blend_ps(0xFFC0, a.values, b.values); + case 249: + return _mm512_mask_blend_ps(0xFFC3, a.values, b.values); + case 250: + return _mm512_mask_blend_ps(0xFFCC, a.values, b.values); + case 251: + return _mm512_mask_blend_ps(0xFFCF, a.values, b.values); + case 252: + return _mm512_mask_blend_ps(0xFFF0, a.values, b.values); + case 253: + return _mm512_mask_blend_ps(0xFFF3, a.values, b.values); + case 254: + return _mm512_mask_blend_ps(0xFFFC, a.values, b.values); + } + return b; + } + static Vectorized> blendv(const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_ps(mask.values, mask.values); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized> arange(c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, + base + step, + base + c10::complex(2)*step, + base + c10::complex(3)*step, + base + c10::complex(4)*step, + base + c10::complex(5)*step, + base + c10::complex(6)*step, + base + c10::complex(7)*step); + } + static Vectorized> set(const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized> loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2*size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < 2*size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2*size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512 hadd_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32(30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_add_ps(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + static inline __m512 hsub_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32(30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_sub_ps(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m512 abs_2_() const { + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto ret = hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return ret; + } + __m512 abs_() const { + return _mm512_sqrt_ps(abs_2_()); // abs abs + } + Vectorized> abs() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + return _mm512_and_ps(abs_(), real_mask); // abs 0 + } + __m512 angle_() const { + //angle = atan2(b/a) + auto b_a = _mm512_permute_ps(values, 0xB1); // b a + return Sleef_atan2f16_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + auto angle = _mm512_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm512_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_ps(); + auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ); + auto abs_val = Vectorized(abs); + + auto div = values / abs_val.values; // x / abs(x) + + return _mm512_mask_blend_ps(mask, div, zero); + } + __m512 real_() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + return _mm512_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512 imag_() const { + const __m512 imag_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF)); + return _mm512_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_ps(imag_(), 0xB1); //b a + } + __m512 conj_() const { + const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512 log2_ = _mm512_set1_ps(std::log(2)); + return _mm512_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m512 log10_ = _mm512_set1_ps(std::log(10)); + return _mm512_div_ps(log(), log10_); + } + Vectorized> log1p() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + const __m512 one = _mm512_set1_ps(1); + + auto conj = conj_(); + auto b_a = _mm512_permute_ps(conj, 0xB1); //-b a + auto ab = _mm512_mul_ps(conj, b_a); //-ab -ab + auto im = _mm512_add_ps(ab, ab); //-2ab -2ab + + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto re = hsub_ps(val_2, _mm512_permute_ps(val_2, 0xB1)); // a*a-b*b b*b-a*a + re = _mm512_sub_ps(one, re); + + auto root = Vectorized(_mm512_mask_blend_ps(0xAAAA, re, im)).sqrt(); //sqrt(re + i*im) + auto ln = Vectorized(_mm512_add_ps(b_a, root)).log(); //ln(iz + sqrt()) + return Vectorized(_mm512_permute_ps(ln.values, 0xB1)).conj(); //-i*ln() + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atan2(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erf() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erfc() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> exp() const { + //exp(a + bi) + // = exp(a)*(cos(b) + sin(b)i) + auto exp = Sleef_expf16_u10(values); //exp(a) exp(b) + exp = _mm512_mask_blend_ps(0xAAAA, exp, _mm512_permute_ps(exp, 0xB1)); //exp(a) exp(a) + + auto sin_cos = Sleef_sincosf16_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] + auto cos_sin = _mm512_mask_blend_ps(0xAAAA, _mm512_permute_ps(sin_cos.y, 0xB1), + sin_cos.x); //cos(b) sin(b) + return _mm512_mul_ps(exp, cos_sin); + } + Vectorized> expm1() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized> floor() const { + return _mm512_floor_ps(values); + } + Vectorized> hypot(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igamma(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igammac(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_ps(); + return _mm512_sub_ps(zero, values); + } + Vectorized> nextafter(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> round() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow(const Vectorized> &exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==(const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator!=(const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator<(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq(const Vectorized>& other) const; + Vectorized> ne(const Vectorized>& other) const; + Vectorized> lt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> le(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> gt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> ge(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +template <> Vectorized> inline operator+(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_add_ps(a, b); +} + +template <> Vectorized> inline operator-(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_sub_ps(a, b); +} + +template <> Vectorized> inline operator*(const Vectorized> &a, + const Vectorized> &b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm512_mul_ps(a, b); //ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); //d c + d_c = _mm512_xor_ps(sign_mask, d_c); //d -c + auto ad_bc = _mm512_mul_ps(a, d_c); //ad -bc + + auto ret = Vectorized>::hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc + return ret; +} + +template <> Vectorized> inline operator/(const Vectorized> &a, + const Vectorized> &b) { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() + //im = (bc - ad)/abs_2() + const __m512 sign_mask = _mm512_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); + auto ac_bd = _mm512_mul_ps(a, b); //ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); //d c + d_c = _mm512_xor_ps(sign_mask, d_c); //-d c + auto ad_bc = _mm512_mul_ps(a, d_c); //-ad bc + + auto re_im = Vectorized>::hadd_ps(ac_bd, ad_bc);//ac + bd bc - ad + return _mm512_div_ps(re_im, b.abs_2_()); +} + +// reciprocal. Implement this here so we can use multiplication. +Vectorized> Vectorized>::reciprocal() const { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() = c/abs_2() + //im = (bc - ad)/abs_2() = d/abs_2() + const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto c_d = _mm512_xor_ps(sign_mask, values); //c -d + return _mm512_div_ps(c_d, abs_2_()); +} + +Vectorized> Vectorized>::atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + const __m512 i = _mm512_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + const Vectorized i_half = _mm512_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, + 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); + + auto sum = Vectorized(_mm512_add_ps(i, values)); // a 1+b + auto sub = Vectorized(_mm512_sub_ps(i, values)); // -a 1-b + auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) + return i_half*ln; // i/2*ln() +} + +template <> +Vectorized> inline maximum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(max, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline minimum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(min, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline operator&(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized> inline operator|(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized> inline operator^(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_ps(a, b); +} + +Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + return (*this == other) & Vectorized>(_mm512_set1_ps(1.0f)); +} + +Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + return (*this != other) & Vectorized>(_mm512_set1_ps(1.0f)); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h new file mode 100644 index 00000000000..7128219748a --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h @@ -0,0 +1,454 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized { +private: + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + // values needs to be public for compilation with clang + // as vec512.h uses it + __m512d values; + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(double val) { + values = _mm512_set1_pd(val); + } + Vectorized(double val1, double val2, double val3, double val4, + double val5, double val6, double val7, double val8) { + values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8); + } + operator __m512d() const { + return values; + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_pd(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized arange(double base = 0., step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, + base + 7 * step); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + + __at_align__ double tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(double)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(double)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_pd(-0.f); + return _mm512_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm512_castsi512_pd(zero_vector); + const auto nan_vec = _mm512_set1_pd(NAN); + const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ); + const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask, + 0xFFFFFFFFFFFFFFFF); + const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan), + zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_pd(c10::pi); + + const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand8_u10(values)); + } + Vectorized atan2(const Vectorized &b) const { + return Vectorized(Sleef_atan2d8_u10(values, b)); + } + Vectorized copysign(const Vectorized &sign) const { + return Vectorized(Sleef_copysignd8(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd8_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d8_u10(values)); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd8(values, q)); + } + Vectorized hypot(const Vectorized &b) const { + return Vectorized(Sleef_hypotd8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd8_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind8_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd8_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd8_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized floor() const { + return _mm512_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm512_xor_pd(_mm512_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized &b) const { + return Vectorized(Sleef_nextafterd8(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd8_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad8_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm512_div_pd(_mm512_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values)); + } + Vectorized pow(const Vectorized &b) const { + return Vectorized(Sleef_powd8_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mul_pd(a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return _mm512_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized max = _mm512_max_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized min = _mm512_min_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return _mm512_min_pd(max, _mm512_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return _mm512_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return _mm512_min_pd(max, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_pd(a, b); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#pragma unroll + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i)); + } +#pragma unroll + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return _mm512_fmadd_pd(a, b, c); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h new file mode 100644 index 00000000000..1a2b113de9d --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -0,0 +1,469 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized { +private: + static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0}; +public: + __m512 values; + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(float val) { + values = _mm512_set1_ps(val); + } + Vectorized(float val1, float val2, float val3, float val4, + float val5, float val6, float val7, float val8, + float val9, float val10, float val11, float val12, + float val13, float val14, float val15, float val16) { + values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8, + val9, val10, val11, val12, val13, val14, val15, val16); + } + operator __m512() const { + return values; + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_ps(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + __at_align__ float tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, reinterpret_cast(ptr), count * sizeof(float)); + return _mm512_loadu_ps(tmp_values); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_ps(-0.f); + return _mm512_andnot_ps(mask, values); + } + Vectorized angle() const { + __m512 zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec), + not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec), + zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf16_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf16_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf16_u10(values)); + } + Vectorized atan2(const Vectorized &b) const { + return Vectorized(Sleef_atan2f16_u10(values, b)); + } + Vectorized copysign(const Vectorized &sign) const { + return Vectorized(Sleef_copysignf16(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erff16_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf16_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf16_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f16_u10(values)); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf16(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf16_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f16_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f16_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf16_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf16_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf16_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf16_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf16_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized floor() const { + return _mm512_floor_ps(values); + } + Vectorized hypot(const Vectorized &b) const { + return Vectorized(Sleef_hypotf16_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized &b) const { + return Vectorized(Sleef_nextafterf16(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf16_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf16_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf16_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm512_div_ps(_mm512_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); + } + Vectorized pow(const Vectorized &b) const { + return Vectorized(Sleef_powf16_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mul_ps(a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return _mm512_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto max = _mm512_max_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask, + 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto min = _mm512_min_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask, + 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return _mm512_min_ps(max, _mm512_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return _mm512_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return _mm512_max_ps(min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_ps(a, b); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#pragma unroll + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); + } +#pragma unroll + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return _mm512_fmadd_ps(a, b, c); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h new file mode 100644 index 00000000000..cc866c065bf --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -0,0 +1,1173 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +namespace at { +namespace vec { +namespace { + +#ifdef CPU_CAPABILITY_AVX512 + +struct Vectorizedi { +protected: + __m512i values; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; + static inline __m512i invert(const __m512i& v) { + const auto ones = _mm512_set1_epi64(-1); + return _mm512_xor_si512(ones, v); + } +public: + Vectorizedi() {} + Vectorizedi(__m512i v) : values(v) {} + operator __m512i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX512 + +#ifdef CPU_CAPABILITY_AVX512 + +template <> +class Vectorized : public Vectorizedi { +private: + static const Vectorized ones; +public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int64_t v) { values = _mm512_set1_epi64(v); } + Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4, + int64_t val5, int64_t val6, int64_t val7, int64_t val8) { + values = _mm512_setr_epi64(val1, val2, val3, val4, + val5, val6, val7, val8); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi64(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mask_ = _mm512_cmp_epi64_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi64(mask_, a.values, b.values); + } + template + static Vectorized arange(int64_t base = 0, step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step); + } + static Vectorized + set(Vectorized a, Vectorized b, int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ int64_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int64_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto is_larger_mask = _mm512_cmpgt_epi64_mask(zero_vector, values); + auto is_larger = _mm512_mask_set1_epi64(zero_vector, is_larger_mask, 0xFFFFFFFFFFFFFFFF); + auto inverse = _mm512_xor_si512(values, is_larger); + return _mm512_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi64(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +class Vectorized : public Vectorizedi { +private: + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; +public: + using value_type = int32_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { values = _mm512_set1_epi32(v); } + Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4, + int32_t val5, int32_t val6, int32_t val7, int32_t val8, + int32_t val9, int32_t val10, int32_t val11, int32_t val12, + int32_t val13, int32_t val14, int32_t val15, int32_t val16) { + values = _mm512_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8, + val9, val10, val11, val12, val13, val14, val15, val16); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi32(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi32(0xFFFFFFFF); + auto mask_ = _mm512_cmp_epi32_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask_, a.values, b.values); + } + template + static Vectorized arange(int32_t base = 0, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); + } + static Vectorized + set(Vectorized a, Vectorized b, int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + __at_align__ int32_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int32_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); + } + } + void dump() const { + for (size_t i = 0; i < size(); ++i) { + std::cout << (int)((value_type*)&values)[i] << " "; + } + std::cout << std::endl; + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t *src, float *dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +# pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + auto input_vec = _mm512_loadu_si512(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_ps(input_vec); + _mm512_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +# pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t *src, double *dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +# pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + auto input_256_vec = _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_pd(input_256_vec); + _mm512_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +# pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +class Vectorized : public Vectorizedi { +private: + static const Vectorized ones; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + using value_type = int16_t; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { values = _mm512_set1_epi16(v); } + Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4, + int16_t val5, int16_t val6, int16_t val7, int16_t val8, + int16_t val9, int16_t val10, int16_t val11, int16_t val12, + int16_t val13, int16_t val14, int16_t val15, int16_t val16, + int16_t val17, int16_t val18, int16_t val19, int16_t val20, + int16_t val21, int16_t val22, int16_t val23, int16_t val24, + int16_t val25, int16_t val26, int16_t val27, int16_t val28, + int16_t val29, int16_t val30, int16_t val31, int16_t val32) { + values = _mm512_set_epi16(val32, val31, val30, val29, val28, val27, val26, val25, + val24, val23, val22, val21, val20, val19, val18, val17, + val16, val15, val14, val13, val12, val11, val10, val9, + val8, val7, val6, val5, val4, val3, val2, val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange(int16_t base = 0, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, + base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step, + base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, + base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, + base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step + ); + } + static Vectorized + set(Vectorized a, Vectorized b, int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +class Vectorized : public Vectorizedi { +private: + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; +public: + using value_type = int8_t; + static constexpr int size() { + return 64; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int8_t v) { values = _mm512_set1_epi8(v); } + Vectorized(int8_t val1, int8_t val2, int8_t val3, int8_t val4, + int8_t val5, int8_t val6, int8_t val7, int8_t val8, + int8_t val9, int8_t val10, int8_t val11, int8_t val12, + int8_t val13, int8_t val14, int8_t val15, int8_t val16, + int8_t val17, int8_t val18, int8_t val19, int8_t val20, + int8_t val21, int8_t val22, int8_t val23, int8_t val24, + int8_t val25, int8_t val26, int8_t val27, int8_t val28, + int8_t val29, int8_t val30, int8_t val31, int8_t val32, + int8_t val33, int8_t val34, int8_t val35, int8_t val36, + int8_t val37, int8_t val38, int8_t val39, int8_t val40, + int8_t val41, int8_t val42, int8_t val43, int8_t val44, + int8_t val45, int8_t val46, int8_t val47, int8_t val48, + int8_t val49, int8_t val50, int8_t val51, int8_t val52, + int8_t val53, int8_t val54, int8_t val55, int8_t val56, + int8_t val57, int8_t val58, int8_t val59, int8_t val60, + int8_t val61, int8_t val62, int8_t val63, int8_t val64){ + values = _mm512_set_epi8(val64, val63, val62, val61, val60, val59, val58, val57, + val56, val55, val54, val53,val52, val51, val50, val49, + val48, val47, val46, val45, val44, val43, val42, val41, + val40, val39, val38, val37, val36, val35, val34, val33, + val32, val31, val30, val29, val28, val27, val26, val25, + val24, val23, val22, val21, val20, val19, val18, val17, + val16, val15, val14, val13, val12, val11, val10, val9, + val8, val7, val6, val5, val4, val3, val2, val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi8(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + template + static Vectorized arange(int8_t base = 0, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, + base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step, + base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, + base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, + base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step, + base + 32 * step, base + 33 * step, base + 34 * step, base + 35 * step, + base + 36 * step, base + 37 * step, base + 38 * step, base + 39 * step, + base + 40 * step, base + 41 * step, base + 42 * step, base + 43 * step, + base + 44 * step, base + 45 * step, base + 46 * step, base + 47 * step, + base + 48 * step, base + 49 * step, base + 50 * step, base + 51 * step, + base + 52 * step, base + 53 * step, base + 54 * step, base + 55 * step, + base + 56 * step, base + 57 * step, base + 58 * step, base + 59 * step, + base + 60 * step, base + 61 * step, base + 62 * step, base + 63 * step); + } + static Vectorized + set(Vectorized a, Vectorized b, int8_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + case 32: + return blend<0xFFFFFFFF>(a, b); + case 33: + return blend<0x1FFFFFFFF>(a, b); + case 34: + return blend<0x3FFFFFFFF>(a, b); + case 35: + return blend<0x7FFFFFFFF>(a, b); + case 36: + return blend<0xFFFFFFFFF>(a, b); + case 37: + return blend<0x1FFFFFFFFF>(a, b); + case 38: + return blend<0x3FFFFFFFFF>(a, b); + case 39: + return blend<0x7FFFFFFFFF>(a, b); + case 40: + return blend<0xFFFFFFFFFF>(a, b); + case 41: + return blend<0x1FFFFFFFFFF>(a, b); + case 42: + return blend<0x3FFFFFFFFFF>(a, b); + case 43: + return blend<0x7FFFFFFFFFF>(a, b); + case 44: + return blend<0xFFFFFFFFFFF>(a, b); + case 45: + return blend<0x1FFFFFFFFFFF>(a, b); + case 46: + return blend<0x3FFFFFFFFFFF>(a, b); + case 47: + return blend<0x7FFFFFFFFFFF>(a, b); + case 48: + return blend<0xFFFFFFFFFFFF>(a, b); + case 49: + return blend<0x1FFFFFFFFFFFF>(a, b); + case 50: + return blend<0x3FFFFFFFFFFFF>(a, b); + case 51: + return blend<0x7FFFFFFFFFFFF>(a, b); + case 52: + return blend<0xFFFFFFFFFFFFF>(a, b); + case 53: + return blend<0x1FFFFFFFFFFFFF>(a, b); + case 54: + return blend<0x3FFFFFFFFFFFFF>(a, b); + case 55: + return blend<0x7FFFFFFFFFFFFF>(a, b); + case 56: + return blend<0xFFFFFFFFFFFFFF>(a, b); + case 57: + return blend<0x1FFFFFFFFFFFFFF>(a, b); + case 58: + return blend<0x3FFFFFFFFFFFFFF>(a, b); + case 59: + return blend<0x7FFFFFFFFFFFFFF>(a, b); + case 60: + return blend<0xFFFFFFFFFFFFFFF>(a, b); + case 61: + return blend<0x1FFFFFFFFFFFFFFF>(a, b); + case 62: + return blend<0x3FFFFFFFFFFFFFFF>(a, b); + case 63: + return blend<0x7FFFFFFFFFFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int8_t count) { + __at_align__ int8_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (size_t i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int8_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int8_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); + } + } + const int8_t& operator[](int idx) const = delete; + int8_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi8(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi64(a, b); +} + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi16(a, b); +} + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +// Negation. Defined here so we can utilize operator- +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mullo_epi64(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_512(const Vectorized& a, const Vectorized& b, Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + // We don't have an instruction for multiplying int8_t + return int_elementwise_binary_512(a, b, std::multiplies()); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi64(a, b); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi32(a, b); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi16(a, b); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi8(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi64(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi32(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi16(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi8(a, b); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi32(max_val, _mm512_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi16(max_val, _mm512_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi8(max_val, _mm512_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi64(max_val, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi64(min_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi8(min_val, a); +} + +template +Vectorized inline convert_to_int32(const T* ptr) { + return Vectorized::loadu(ptr); +} + +template<> +Vectorized inline convert_to_int32(const int8_t* ptr) { + return _mm512_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast(ptr))); +} + +template<> +Vectorized inline convert_to_int32(const uint8_t* ptr) { + return _mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast(ptr))); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} + +template>::value, int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +template>::value, int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +template>::value, int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} +template>::value, int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm512_xor_si512(a, _mm512_set1_epi32(-1)); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h new file mode 100644 index 00000000000..5b5ac195f3c --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -0,0 +1,1195 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vectorized::float_num_vecs +// iterations. + +namespace at { +namespace vec { +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +struct Vectorizedqi { + protected: + __m512i vals __attribute__((aligned(64))); + + public: + Vectorizedqi() {} + Vectorizedqi(__m512i v) : vals(v) {} + operator __m512i() const { + return vals; + } +}; + + +template +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + T min_val, + T max_val); + +template <> +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int32_t min_val, + int32_t max_val) { + // This function is for linkage only, will not be used + AT_ERROR("pack_saturate_and_clamp is not supported"); +} + +template <> +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int8_t min_val, + int8_t max_val) { + __m512i packed_and_sat = _mm512_packs_epi16(first, second); + return _mm512_max_epi8( + _mm512_set1_epi8(min_val), + _mm512_min_epi8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template <> +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + uint8_t min_val, + uint8_t max_val) { + __m512i packed_and_sat = _mm512_packus_epi16(first, second); + return _mm512_max_epu8( + _mm512_set1_epi8(min_val), + _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + + +template +inline void __attribute__((always_inline)) QuantizeAvx512( + const float* src, + typename T::underlying* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 16; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m512i min_v = _mm512_set1_epi32(min_val); + const __m512i max_v = _mm512_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale); + // clang-format off + static const __m512i shuffle_mask_v = _mm512_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m512i permute_mask_v = + _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00); + __m512i permute_mask_l8_v = + _mm512_set_epi32(0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0c, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm512_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // y + __m512 y_vals = _mm512_load_ps(src + i + VLEN); + __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm512_min_ps(y_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // z + __m512 z_vals = _mm512_load_ps(src + i + 2 * VLEN); + __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm512_min_ps(z_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // w + __m512 w_vals = _mm512_load_ps(src + i + 3 * VLEN); + __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm512_min_ps(w_transformed_v, _mm512_set1_ps(int32_float_max_val)); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + y_rounded_v = _mm512_add_epi32(y_rounded_v, _mm512_set1_epi32(zero_point)); + z_rounded_v = _mm512_add_epi32(z_rounded_v, _mm512_set1_epi32(zero_point)); + w_rounded_v = _mm512_add_epi32(w_rounded_v, _mm512_set1_epi32(zero_point)); + + __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v); + __m512i xyzw_clamped_v = pack_saturate_and_clamp( + xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = + _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX512 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + __m512i x_clipped_v = + _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(dst + i), + _mm512_castsi512_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + transformed = zero_point + nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template<> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type size() { + return 16; + } + + static constexpr int float_num_vecs() { + return 1; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { vals = vals_;} + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized retval; + auto rhs_data = (__m512)rhs[0]; + at::native::quantize_vec( + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 16); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi32( + _mm512_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm512_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + + __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier_v); + __m512i rounded = _mm512_cvtps_epi32(scaled); + return _mm512_add_epi32(rounded, zero_point_v); + } + + void dump() const { + for (size_t i = 0; i < 16; ++i) { + std::cout << ((int32_t*)&vals)[i] << " "; + } + std::cout << std::endl; + } + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m512i RequantizeAvx512( + const std::array, 4>& inp, + __m512 multiplier, + __m512i zp) { + static_assert( + std::is_same::value || std::is_same::value, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m512i permute_mask_v = + _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00); + __m512 x_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier); + __m512 y_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[1]), multiplier); + __m512 z_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[2]), multiplier); + __m512 w_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[3]), multiplier); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m512i x_v = _mm512_add_epi32(x_rounded_v, zp); + __m512i y_v = _mm512_add_epi32(y_rounded_v, zp); + __m512i z_v = _mm512_add_epi32(z_rounded_v, zp); + __m512i w_v = _mm512_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m512i xy_packed_v = _mm512_packs_epi32(x_v, y_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_v, w_v); + + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 x8-11 y8-11 z8-11 w8-11 x12-15 y12-15 z12-15 w12-15 + */ + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + return xyzw_clamped_v; +} + +template<> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m512i vals_) { vals = vals_;} + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + private: + __m512i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm512_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_neg_zp_premul) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi8( + _mm512_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512i int32_val0 = cvtepi8_epi32(int_val0); + __m512i int32_val1 = cvtepi8_epi32(int_val1); + __m512i int32_val2 = cvtepi8_epi32(int_val2); + __m512i int32_val3 = cvtepi8_epi32(int_val3); + + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); + + __m512i int32_b0 = cvtepi8_epi32(int_b0); + __m512i int32_b1 = cvtepi8_epi32(int_b1); + __m512i int32_b2 = cvtepi8_epi32(int_b2); + __m512i int32_b3 = cvtepi8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + + return {Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + void dump() const { + for (size_t i = 0; i < size(); ++i) { + std::cout << (int)((value_type*)&vals)[i] << " "; + } + std::cout << std::endl; + } + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template<> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { vals = vals_;} + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + private: + __m512i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm512_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epu8( + _mm512_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512i int32_val0 = cvtepu8_epi32(int_val0); + __m512i int32_val1 = cvtepu8_epi32(int_val1); + __m512i int32_val2 = cvtepu8_epi32(int_val2); + __m512i int32_val3 = cvtepu8_epi32(int_val3); + + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); + + __m512i int32_b0 = cvtepu8_epi32(int_b0); + __m512i int32_b1 = cvtepu8_epi32(int_b1); + __m512i int32_b2 = cvtepu8_epi32(int_b2); + __m512i int32_b3 = cvtepu8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + return {Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + void dump() const { + for (size_t i = 0; i < size(); ++i) { + std::cout << (int)((value_type*)&vals)[i] << " "; + } + std::cout << std::endl; + } + private: + + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +#else + +// NOTE: These are low-performance implementations that we fall back on. + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / 8; + } + + static constexpr int int_num_vecs() { + return size() / 8; + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[16]; + for (int j = 0; j < 16; ++j) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], zero_point[j], T(vals[16 * i + j])); + } + rv[i] = Vectorized(tmp_vals[0], + tmp_vals[1], + tmp_vals[2], + tmp_vals[3], + tmp_vals[4], + tmp_vals[5], + tmp_vals[6], + tmp_vals[7], + tmp_vals[8], + tmp_vals[9], + tmp_vals[10], + tmp_vals[11], + tmp_vals[12], + tmp_vals[13], + tmp_vals[14], + tmp_vals[15]); + } + return rv; + } + + void dump() const { + for (int i = 0; i < size(); ++i) { + std::cout << vals[i] << " "; + } + std::cout << std::endl; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(MSVC) + +}}} diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_base.h b/aten/src/ATen/cpu/vec/vec_base.h similarity index 84% rename from aten/src/ATen/cpu/vec/vec256/vec256_base.h rename to aten/src/ATen/cpu/vec/vec_base.h index 596dac67c2c..da5f318bf53 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include @@ -32,13 +32,28 @@ #include #include +// These macros helped us unify vec_base.h +#ifdef CPU_CAPABILITY_AVX512 #if defined(__GNUC__) -#define __at_align32__ __attribute__((aligned(32))) +#define __at_align__ __attribute__((aligned(64))) #elif defined(_WIN32) -#define __at_align32__ __declspec(align(32)) +#define __at_align__ __declspec(align(64)) #else -#define __at_align32__ +#define __at_align__ #endif +#define VECTOR_WIDTH 64 +#define int_vector __m512i +#else // CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(32))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(32)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 32 +#define int_vector __m256i +#endif // CPU_CAPABILITY_AVX512 namespace at { namespace vec { @@ -70,11 +85,11 @@ using int_same_size_t = typename int_of_size::type; // NOTE: If you specialize on a type, you must define all operations! -// emulates vectorized types +// emulates Vectorized types template struct Vectorized { private: - __at_align32__ T values[32 / sizeof(T)]; + __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; public: using value_type = T; using size_type = int; @@ -111,7 +126,7 @@ public: // identifier is odr-used or not, and in any case it's hard to tell if // a variable is odr-used or not. So best to just cut the problem at the root. static constexpr size_type size() { - return 32 / sizeof(T); + return VECTOR_WIDTH / sizeof(T); } Vectorized() : values{0} {} Vectorized(T val) { @@ -134,60 +149,60 @@ public: template static Vectorized blend(const Vectorized& a, const Vectorized& b) { int64_t mask = mask_; - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i < size(); i++) { if (mask & 0x01) { - vec[i] = b[i]; + vector[i] = b[i]; } else { - vec[i] = a[i]; + vector[i] = a[i]; } mask = mask >> 1; } - return vec; + return vector; } static Vectorized blendv(const Vectorized& a, const Vectorized& b, const Vectorized& mask) { - Vectorized vec; + Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); for (int64_t i = 0; i < size(); i++) { if (buffer[i] & 0x01) { - vec[i] = b[i]; + vector[i] = b[i]; } else { - vec[i] = a[i]; + vector[i] = a[i]; } } - return vec; + return vector; } template // step sometimes requires a higher precision type (e.g., T=int, step_t=double) static Vectorized arange(T base = static_cast(0), step_t step = static_cast(1)) { - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i < size(); i++) { - vec.values[i] = base + i * step; + vector.values[i] = base + i * step; } - return vec; + return vector; } static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) { - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i < size(); i++) { if (i < count) { - vec[i] = b[i]; + vector[i] = b[i]; } else { - vec[i] = a[i]; + vector[i] = a[i]; } } - return vec; + return vector; } static Vectorized loadu(const void* ptr) { - Vectorized vec; - std::memcpy(vec.values, ptr, 32); - return vec; + Vectorized vector; + std::memcpy(vector.values, ptr, VECTOR_WIDTH); + return vector; } static Vectorized loadu(const void* ptr, int64_t count) { - Vectorized vec; - std::memcpy(vec.values, ptr, count * sizeof(T)); - return vec; + Vectorized vector; + std::memcpy(vector.values, ptr, count * sizeof(T)); + return vector; } void store(void* ptr, int count = size()) const { std::memcpy(ptr, values, count * sizeof(T)); @@ -203,15 +218,15 @@ public: return mask; } Vectorized isnan() const { - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (_isnan(values[i])) { - std::memset(static_cast(vec.values + i), 0xFF, sizeof(T)); + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { - std::memset(static_cast(vec.values + i), 0, sizeof(T)); + std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } - return vec; + return vector; } Vectorized map(T (*const f)(T)) const { Vectorized ret; @@ -488,15 +503,15 @@ private: template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (op(values[i], other.values[i])) { - std::memset(static_cast(vec.values + i), 0xFF, sizeof(T)); + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { - std::memset(static_cast(vec.values + i), 0, sizeof(T)); + std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } - return vec; + return vector; } public: @@ -511,11 +526,11 @@ private: template inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { // 1 if the pred is true, otherwise 0. - Vectorized vec; + Vectorized vector; for (int i = 0; i != size(); ++ i) { - vec[i] = bool(op(values[i], other.values[i])); + vector[i] = bool(op(values[i], other.values[i])); } - return vec; + return vector; } public: @@ -668,41 +683,62 @@ Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_ struct Vectorizedi; -#ifdef CPU_CAPABILITY_AVX2 - +#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { - __m256i buffer; - __m256i a_buffer = _mm256_loadu_si256(reinterpret_cast((const T*)a)); - __m256i b_buffer = _mm256_loadu_si256(reinterpret_cast((const T*)b)); + int_vector buffer; +#if defined(CPU_CAPABILITY_AVX2) + int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); +#elif defined(CPU_CAPABILITY_AVX512) + int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); +#endif buffer = op(a_buffer, b_buffer); - __at_align32__ T results[Vectorized::size()]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(results), buffer); + __at_align__ T results[Vectorized::size()]; + +#if defined(CPU_CAPABILITY_AVX2) + _mm256_store_si256(reinterpret_cast(results), buffer); +#elif defined(CPU_CAPABILITY_AVX512) + _mm512_store_si512(reinterpret_cast(results), buffer); +#endif return Vectorized::loadu(results); } template>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { - // We enclose _mm256_and_si256 with lambda because it is always_inline - return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_and_si256(a, b); }); + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); +#endif } template>::value, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { - // We enclose _mm256_or_si256 with lambda because it is always_inline - return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_or_si256(a, b); }); + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); +#endif } template>::value, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { - // We enclose _mm256_xor_si256 with lambda because it is always_inline - return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }); + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); +#endif } #else template static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { - static constexpr uint32_t element_no = 32 / sizeof(intmax_t); - __at_align32__ intmax_t buffer[element_no]; + static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); + __at_align__ intmax_t buffer[element_no]; const intmax_t *a_ptr = reinterpret_cast((const T*) a); const intmax_t *b_ptr = reinterpret_cast((const T*) b); for (uint32_t i = 0U; i < element_no; ++ i) { @@ -724,12 +760,12 @@ inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_xor()); } -#endif +#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template>::value, int> = 0> inline Vectorized operator~(const Vectorized& a) { Vectorized ones; // All bits are 1 - memset((T*) ones, 0xFF, 32); + memset((T*) ones, 0xFF, VECTOR_WIDTH); return a ^ ones; } @@ -802,7 +838,9 @@ inline mask_gather(const Vectorized& src, T const* base_addr, } // Cast a given vector to another type without changing the bits representation. -// So a Vec of 256 bits containing all ones can be cast to a +// So a Vectorized of 512 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative 1s). +// A Vec of 256 bits containing all ones can be cast to a // Vec of 256 bits containing all ones (i.e., four negative 1s). namespace { // There is a struct here because we don't have static_if and I can't @@ -840,10 +878,16 @@ inline Vectorized> convert_to_int_of_same_size(const Vectoriz return Vectorized>::loadu(static_cast(buffer)); } -// E.g., inputs: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} -// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} -// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} -// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// Example inputs for AVX512: +// a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// returns: +// Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} template inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> deinterleave2(const Vectorized& a, const Vectorized& b) { @@ -866,8 +910,14 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { } // inverse operation of deinterleave2 -// E.g., inputs: a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} -// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// Example inputs for AVX512: +// a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// returns, for AVX512: +// Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} template diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index 08dca99de04..b9cc47f3fe7 100644 --- a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -35,21 +35,8 @@ #include #endif -// [Note SSE-AVX transitions] -// There is a bug in Glibc2.23 -// https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall -// when using AVX/AVX2 code resolves this. -#if defined(CPU_CAPABILITY_AVX) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23 -#define DL_RUNTIME_BUG(op, type_) \ - using value_t = typename c10::scalar_value_type::type;\ - volatile value_t x = (value_t)(1); \ - x = std::op(x); \ - _mm256_zeroall(); -#define DL_RUNTIME_BUG_BFLOAT16() _mm256_zeroall(); -#else #define DL_RUNTIME_BUG(op, type_) #define DL_RUNTIME_BUG_BFLOAT16() -#endif namespace at { namespace vml { @@ -117,36 +104,36 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) { }); \ } -IMPLEMENT_VML_BUG(abs) -IMPLEMENT_VML_BUG(acos) -IMPLEMENT_VML_BUG(asin) -IMPLEMENT_VML_BUG(atan) -IMPLEMENT_VML_BUG(ceil) -IMPLEMENT_VML_BUG(cos) +IMPLEMENT_VML(abs) +IMPLEMENT_VML(acos) +IMPLEMENT_VML(asin) +IMPLEMENT_VML(atan) +IMPLEMENT_VML(ceil) +IMPLEMENT_VML(cos) // IMPLEMENT_VML_BUG(cosh) -IMPLEMENT_VML_BUG(erf) -IMPLEMENT_VML_BUG(erfc) +IMPLEMENT_VML(erf) +IMPLEMENT_VML(erfc) IMPLEMENT_VML(erfinv) -IMPLEMENT_VML_BUG(exp) -IMPLEMENT_VML_BUG(expm1) -IMPLEMENT_VML_BUG(floor) +IMPLEMENT_VML(exp) +IMPLEMENT_VML(expm1) +IMPLEMENT_VML(floor) IMPLEMENT_VML(i0) IMPLEMENT_VML(i0e) IMPLEMENT_VML(reciprocal) -IMPLEMENT_VML_BUG(log) -IMPLEMENT_VML_BUG(log10) -IMPLEMENT_VML_BUG(log1p) -IMPLEMENT_VML_BUG(log2) +IMPLEMENT_VML(log) +IMPLEMENT_VML(log10) +IMPLEMENT_VML(log1p) +IMPLEMENT_VML(log2) IMPLEMENT_VML(neg) -IMPLEMENT_VML_BUG(sin) +IMPLEMENT_VML(sin) // IMPLEMENT_VML_BUG(sinh) -IMPLEMENT_VML_BUG(sqrt) -IMPLEMENT_VML_BUG(round) +IMPLEMENT_VML(sqrt) +IMPLEMENT_VML(round) IMPLEMENT_VML(rsqrt) -IMPLEMENT_VML_BUG(tan) -IMPLEMENT_VML_BUG(tanh) -IMPLEMENT_VML_BUG(trunc) -IMPLEMENT_VML_BUG(lgamma) +IMPLEMENT_VML(tan) +IMPLEMENT_VML(tanh) +IMPLEMENT_VML(trunc) +IMPLEMENT_VML(lgamma) #if AT_MKL_ENABLED() && !defined(__APPLE__) diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index b801f819f3e..16b3d1b2874 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -952,91 +952,109 @@ void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl); +REGISTER_AVX512_DISPATCH(eig_stub, &eig_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel); -REGISTER_AVX_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel); -REGISTER_AVX_DISPATCH(lu_solve_stub, &lu_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel); }} // namespace at::native diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 7e3666ef984..ada5ed5ee75 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -16,12 +16,12 @@ static CPUCapability compute_cpu_capability() { return CPUCapability::VSX; } #else + if (strcmp(envar, "avx512") == 0) { + return CPUCapability::AVX512; + } if (strcmp(envar, "avx2") == 0) { return CPUCapability::AVX2; } - if (strcmp(envar, "avx") == 0) { - return CPUCapability::AVX; - } #endif if (strcmp(envar, "default") == 0) { return CPUCapability::DEFAULT; @@ -31,12 +31,13 @@ static CPUCapability compute_cpu_capability() { #if !defined(__powerpc__) && !defined(__s390x__) if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && \ + cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) { + return CPUCapability::AVX512; + } if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) { return CPUCapability::AVX2; } - if (cpuinfo_has_x86_avx()) { - return CPUCapability::AVX; - } } #endif #ifdef HAVE_VSX_CPU_DEFINITION @@ -54,8 +55,8 @@ CPUCapability get_cpu_capability() { void* DispatchStubImpl::get_call_ptr( DeviceType device_type , void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -72,8 +73,8 @@ void* DispatchStubImpl::get_call_ptr( if (!fptr) { fptr = choose_cpu_impl( DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , AVX2 @@ -102,8 +103,8 @@ void* DispatchStubImpl::get_call_ptr( void* DispatchStubImpl::choose_cpu_impl( void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -114,18 +115,26 @@ void* DispatchStubImpl::choose_cpu_impl( ) { auto capability = static_cast(get_cpu_capability()); (void)capability; +#ifdef HAVE_AVX512_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::AVX512)) { + // Quantization kernels have also been disabled on Windows + // for AVX512 because some of their tests are flaky on Windows. + // Ideally, we should have AVX512 kernels for all kernels. + if (C10_UNLIKELY(!AVX512)) { + // dispatch to AVX2, since the AVX512 kernel is missing + TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); + return AVX2; + } else { + return AVX512; + } + } +#endif #ifdef HAVE_AVX2_CPU_DEFINITION if (capability >= static_cast(CPUCapability::AVX2)) { TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); return AVX2; } #endif -#ifdef HAVE_AVX_CPU_DEFINITION - if (capability >= static_cast(CPUCapability::AVX)) { - TORCH_INTERNAL_ASSERT(AVX, "DispatchStub: missing AVX kernel"); - return AVX; - } -#endif #ifdef HAVE_VSX_CPU_DEFINITION if (capability >= static_cast(CPUCapability::VSX)) { TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel"); diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 315f5007dbd..94a2dc421a6 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -9,8 +9,8 @@ // Implements instruction set specific function dispatch. // -// Kernels that may make use of specialized instruction sets (e.g. AVX) are -// compiled multiple times with different compiler flags (e.g. -mavx). A +// Kernels that may make use of specialized instruction sets (e.g. AVX2) are +// compiled multiple times with different compiler flags (e.g. -mavx2). A // DispatchStub contains a table of function pointers for a kernel. At runtime, // the fastest available kernel is chosen based on the features reported by // cpuinfo. @@ -50,8 +50,8 @@ enum class CPUCapability { #ifdef HAVE_VSX_CPU_DEFINITION VSX = 1, #else - AVX = 1, - AVX2 = 2, + AVX2 = 1, + AVX512 = 2, #endif NUM_OPTIONS }; @@ -71,8 +71,8 @@ struct TORCH_API DispatchStubImpl { void* get_call_ptr( DeviceType device_type , void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -89,8 +89,8 @@ struct TORCH_API DispatchStubImpl { */ void* choose_cpu_impl( void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -126,8 +126,8 @@ private: return reinterpret_cast( impl.get_call_ptr(device_type , reinterpret_cast(DEFAULT) -#ifdef HAVE_AVX_CPU_DEFINITION - , reinterpret_cast(AVX) +#ifdef HAVE_AVX512_CPU_DEFINITION + , reinterpret_cast(AVX512) #endif #ifdef HAVE_AVX2_CPU_DEFINITION , reinterpret_cast(AVX2) @@ -155,8 +155,8 @@ public: } static FnPtr DEFAULT; -#ifdef HAVE_AVX_CPU_DEFINITION - static FnPtr AVX; +#ifdef HAVE_AVX512_CPU_DEFINITION + static FnPtr AVX512; #endif #ifdef HAVE_AVX2_CPU_DEFINITION static FnPtr AVX2; @@ -203,10 +203,10 @@ struct RegisterHIPDispatch { #define REGISTER_ARCH_DISPATCH(name, arch, fn) \ template <> decltype(fn) DispatchStub::arch = fn; -#ifdef HAVE_AVX_CPU_DEFINITION -#define REGISTER_AVX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX, fn) +#ifdef HAVE_AVX512_CPU_DEFINITION +#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn) #else -#define REGISTER_AVX_DISPATCH(name, fn) +#define REGISTER_AVX512_DISPATCH(name, fn) #endif #ifdef HAVE_AVX2_CPU_DEFINITION @@ -223,8 +223,8 @@ struct RegisterHIPDispatch { #define REGISTER_NO_CPU_DISPATCH(name, fn_type) \ REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast(nullptr)) \ - REGISTER_AVX_DISPATCH(name, static_cast(nullptr)) \ - REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_AVX512_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) \ REGISTER_VSX_DISPATCH(name, static_cast(nullptr)) #define REGISTER_CUDA_DISPATCH(name, fn) \ @@ -244,6 +244,8 @@ struct RegisterHIPDispatch { // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn) #elif defined(CPU_CAPABILITY) #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define REGISTER_NO_AVX512_DISPATCH(name, fn_type) \ + REGISTER_AVX512_DISPATCH(name, static_cast(nullptr)) #endif diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index b8b9405528c..ff9529d32dd 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -275,10 +275,10 @@ REGISTER_ARCH_DISPATCH( DEFAULT, &_segment_reduce_cpu_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); // Currently some computation is being duplicated across forward and backward. @@ -319,7 +319,7 @@ REGISTER_ARCH_DISPATCH( DEFAULT, &_segment_reduce_cpu_backward_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH( +REGISTER_AVX512_DISPATCH( _segment_reduce_backward_stub, &_segment_reduce_cpu_backward_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/aten/src/ATen/native/cpu/README.md b/aten/src/ATen/native/cpu/README.md index abf83b560ba..b3fbd171ad4 100644 --- a/aten/src/ATen/native/cpu/README.md +++ b/aten/src/ATen/native/cpu/README.md @@ -4,7 +4,7 @@ The most important things to know: compiled multiple times for different instruction sets.** Yes, this folder is named `cpu`, but that doesn't mean put any old CPU kernel it. Only put CPU kernels which need to be compiled -multiple times to take advantage of AVX/SSE instructions, but +multiple times to take advantage of AVX512/AVX2/SSE instructions, but only on processors that support them. **Ensure that all implementations in this folder are put in an @@ -52,14 +52,14 @@ All of the `*.cpp` files in this folder will be compiled under all compiler flags specified by `CPU_CAPABILITY_FLAGS` in `aten/src/ATen/CMakeLists.txt`. The purpose of this is to allow the compilation with various compiler -flags to enable features such as AVX instructions, while using runtime -dispatch, which makes sure only valid instructions will be used on any +flags to enable features such as AVX2 or AVX512 instructions, while using +runtime dispatch, which makes sure only valid instructions will be used on any given platform. -Vectorized.h provides a generic implementation of a vec type that allows +vec.h provides a generic implementation of vec type that allows the programmer to write code packing various primitives (such as floats) -within 256bit registers. vec defines various operators such as + and * -and provides functions to allow operations such as max, min, etc. +within 256bit & 512bits registers. vec defines various operators such as ++ and * and provides functions to allow operations such as max, min, etc. As an example `ReduceOpsKernel.cpp` implements a generic `kernel_` that reduces an entire array using a given associative binary operation such as +. @@ -74,5 +74,5 @@ generic code, which will be compiled under multipled compilation settings. `../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains a generic definition of `sumImplAll`. This function allows the user to reduce over a dimension or all dimensions. The appropiate capability is chosen at -runtime using cpuinfo. If the current platform has AVX, `sumImpl` will be set -to `sumImplAll`. +runtime using cpuinfo. If the current platform has AVX2, `sumImpl` will be set +to `sumImplAll`. diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 30a7dc64a3a..d97e81f4367 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -32,7 +32,8 @@ static inline bool is_outer_reduction(const int64_t* strides) { } template -static inline void reduction128(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) { +static inline void vectorized_reduction(char** data, int64_t n, int64_t stride, + func_t op, vec_func_t vop, bool reduce) { VEC_LOOP_HEADER(func_t, data) const char* in1_ptr = data[1]; Vec acc[4]; @@ -80,7 +81,7 @@ static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); int64_t count = n / (4 * Vec::size()); if (count > 0) { - reduction128(data, count, vector_stride, op, vop, /*reduce=*/true); + vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true); } char* ptrs[3] = { data[0], data[0], data[1] }; int64_t strides[] = { 0, 0, sizeof(scalar_t) }; @@ -92,10 +93,14 @@ template static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) - // reduce down each column of 4 * Vec::size() elements (128 bytes) + // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes) +#if defined(CPU_CAPABILITY_AVX512) + int64_t outer_stride[2] = { 256, 256 }; +#else int64_t outer_stride[2] = { 128, 128 }; +#endif UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { - reduction128(data, size0, inner_stride, op, vop, /*reduce=*/false); + vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false); }); // reduce down the remaining columns diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index 3fdde8c07f1..6318543e7eb 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -219,9 +219,15 @@ inline void _vec_softmax( int64_t outer_stride = dim_size * dim_stride; int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); int vectorized_step = Vec().size(); // Currently, we only support scalar_t with double or float32 - TORCH_CHECK( +#ifdef CPU_CAPABILITY_AVX512 + TORCH_INTERNAL_ASSERT( + (vectorized_step == 16) || (vectorized_step == 8), + "vectorized_step must be 16 with dtype float or 8 with dtype double"); +#else + TORCH_INTERNAL_ASSERT( (vectorized_step == 8) || (vectorized_step == 4), "vectorized_step must be 8 with dtype float or 4 with dtype double"); +#endif parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { int64_t idx = begin; diff --git a/aten/src/ATen/native/cpu/SumKernel.cpp b/aten/src/ATen/native/cpu/SumKernel.cpp index 4ae4b3585f6..73c1c40b3c3 100644 --- a/aten/src/ATen/native/cpu/SumKernel.cpp +++ b/aten/src/ATen/native/cpu/SumKernel.cpp @@ -611,8 +611,15 @@ void nansum_kernel_impl(TensorIterator &iter) { } // namespace (anonymous) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(sum_stub, &sum_kernel_impl); + +// nansum on Float16 has poor accuracy with AVX2, and more so with AVX512. +// So until it's fixed, it won't be dispatched with AVX512. GH issue 59415. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +#ifndef CPU_CAPABILITY_AVX512 REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl); +#else +REGISTER_NO_AVX512_DISPATCH(nansum_stub, reduce_fn); +#endif }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index b0b3507e4d0..c8775031dfe 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -715,8 +715,15 @@ REGISTER_DISPATCH(exponential_stub, &CPU_CAPABILITY::exponential_kernel); REGISTER_DISPATCH(geometric_stub, &CPU_CAPABILITY::geometric_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(log_normal_stub, &CPU_CAPABILITY::log_normal_kernel); +#ifdef CPU_CAPABILITY_AVX512 +// normal_stub isn't being dispatched to AVX512 because it exposes +// flakiness in test_sgd of test/test_optim.py +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(normal_stub, void(*)(Tensor&, const double, const double, c10::optional)); +#else // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(normal_stub, &CPU_CAPABILITY::normal_kernel); +#endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(uniform_stub, &CPU_CAPABILITY::uniform_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/aten/src/ATen/native/cpu/avx_mathfun.h b/aten/src/ATen/native/cpu/avx_mathfun.h index 33f8569e6e8..080cd833d3a 100644 --- a/aten/src/ATen/native/cpu/avx_mathfun.h +++ b/aten/src/ATen/native/cpu/avx_mathfun.h @@ -32,26 +32,17 @@ #include -/* yes I know, the top of this file is quite ugly */ +/* The original source of this file has been modified. */ +#if defined(CPU_CAPABILITY_AVX2) + #if defined(__GNUC__) # define ALIGN32_BEG __attribute__((aligned(32))) #elif defined(_WIN32) # define ALIGN32_BEG __declspec(align(32)) #endif -/* __m128 is ugly to write */ -typedef __m256 v8sf; // vector of 8 float (avx) -typedef __m256i v8si; // vector of 8 int (avx) -typedef __m128i v4si; // vector of 8 int (avx) - -#define _PI32AVX_CONST(Name, Val) \ - static const ALIGN32_BEG int _pi32avx_##Name[4] = { Val, Val, Val, Val } - -_PI32AVX_CONST(1, 1); -_PI32AVX_CONST(inv1, ~1); -_PI32AVX_CONST(2, 2); -_PI32AVX_CONST(4, 4); - +typedef __m256 v8sf; // vector of 8 float (avx2) +typedef __m256i v8si; // vector of 8 int (avx2) /* declare some AVX constants -- why can't I figure a better way to do that? */ #define _PS256_CONST(Name, Val) \ @@ -91,67 +82,6 @@ _PS256_CONST(cephes_log_p8, + 3.3333331174E-1); _PS256_CONST(cephes_log_q1, -2.12194440e-4); _PS256_CONST(cephes_log_q2, 0.693359375); -#ifndef CPU_CAPABILITY_AVX2 - -typedef union imm_xmm_union { - v8si imm; - v4si xmm[2]; -} imm_xmm_union; - -#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) { \ - imm_xmm_union u __attribute__((aligned(32))); \ - u.imm = imm_; \ - xmm0_ = u.xmm[0]; \ - xmm1_ = u.xmm[1]; \ -} - -#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) { \ - imm_xmm_union u __attribute__((aligned(32))); \ - u.xmm[0]=xmm0_; u.xmm[1]=xmm1_; imm_ = u.imm; \ - } - - -#define AVX2_BITOP_USING_SSE2(fn) \ -static inline v8si _mm256_##fn(v8si x, int a) \ -{ \ - /* use SSE2 instruction to perform the bitop AVX2 */ \ - v4si x1, x2; \ - v8si ret; \ - COPY_IMM_TO_XMM(x, x1, x2); \ - x1 = _mm_##fn(x1,a); \ - x2 = _mm_##fn(x2,a); \ - COPY_XMM_TO_IMM(x1, x2, ret); \ - return(ret); \ -} - -#warning "Using SSE2 to perform AVX2 bitshift ops" -AVX2_BITOP_USING_SSE2(slli_epi32) -AVX2_BITOP_USING_SSE2(srli_epi32) - -#define AVX2_INTOP_USING_SSE2(fn) \ -static inline v8si _mm256_##fn(v8si x, v8si y) \ -{ \ - /* use SSE2 instructions to perform the AVX2 integer operation */ \ - v4si x1, x2; \ - v4si y1, y2; \ - v8si ret; \ - COPY_IMM_TO_XMM(x, x1, x2); \ - COPY_IMM_TO_XMM(y, y1, y2); \ - x1 = _mm_##fn(x1,y1); \ - x2 = _mm_##fn(x2,y2); \ - COPY_XMM_TO_IMM(x1, x2, ret); \ - return(ret); \ -} - -#warning "Using SSE2 to perform AVX2 integer ops" -AVX2_INTOP_USING_SSE2(and_si128) -AVX2_INTOP_USING_SSE2(andnot_si128) -AVX2_INTOP_USING_SSE2(cmpeq_epi32) -AVX2_INTOP_USING_SSE2(sub_epi32) -AVX2_INTOP_USING_SSE2(add_epi32) - -#endif /* CPU_CAPABILITY_AVX2 */ - /* natural logarithm computed for 8 simultaneous float return NaN for x <= 0 @@ -326,11 +256,6 @@ inline v8sf sin256_ps(v8sf x) { // any x v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; v8si imm0, imm2; -#ifndef CPU_CAPABILITY_AVX2 - v4si imm0_1, imm0_2; - v4si imm2_1, imm2_2; -#endif - sign_bit = x; /* take the absolute value */ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); @@ -346,7 +271,6 @@ inline v8sf sin256_ps(v8sf x) { // any x If we don't have AVX, let's perform them using SSE2 directives */ -#ifdef CPU_CAPABILITY_AVX2 /* store the integer part of y in mm0 */ imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ @@ -366,35 +290,6 @@ inline v8sf sin256_ps(v8sf x) { // any x */ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0); -#else - /* we use SSE2 routines to perform the integer ops */ - COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2); - - imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1); - - COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2); - y = _mm256_cvtepi32_ps(imm2); - - imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4); - imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4); - - imm0_1 = _mm_slli_epi32(imm0_1, 29); - imm0_2 = _mm_slli_epi32(imm0_2, 29); - - COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2); - - imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); - imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); - - COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); -#endif v8sf swap_sign_bit = _mm256_castsi256_ps(imm0); v8sf poly_mask = _mm256_castsi256_ps(imm2); @@ -453,18 +348,12 @@ inline v8sf cos256_ps(v8sf x) { // any x v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y; v8si imm0, imm2; -#ifndef CPU_CAPABILITY_AVX2 - v4si imm0_1, imm0_2; - v4si imm2_1, imm2_2; -#endif - /* take the absolute value */ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); /* scale by 4/Pi */ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); -#ifdef CPU_CAPABILITY_AVX2 /* store the integer part of y in mm0 */ imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ @@ -479,39 +368,6 @@ inline v8sf cos256_ps(v8sf x) { // any x /* get the polynom selection mask */ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0); -#else - - /* we use SSE2 routines to perform the integer ops */ - COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2); - - imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1); - - COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2); - y = _mm256_cvtepi32_ps(imm2); - - imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2); - - imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4); - imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4); - - imm0_1 = _mm_slli_epi32(imm0_1, 29); - imm0_2 = _mm_slli_epi32(imm0_2, 29); - - COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2); - - imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); - imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); - - COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); -#endif v8sf sign_bit = _mm256_castsi256_ps(imm0); v8sf poly_mask = _mm256_castsi256_ps(imm2); @@ -571,12 +427,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y; v8si imm0, imm2, imm4; -#ifndef CPU_CAPABILITY_AVX2 - v4si imm0_1, imm0_2; - v4si imm2_1, imm2_2; - v4si imm4_1, imm4_2; -#endif - sign_bit_sin = x; /* take the absolute value */ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); @@ -586,7 +436,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { /* scale by 4/Pi */ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); -#ifdef CPU_CAPABILITY_AVX2 /* store the integer part of y in imm2 */ imm2 = _mm256_cvttps_epi32(y); @@ -606,38 +455,7 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0); //v8sf poly_mask = _mm256_castsi256_ps(imm2); -#else - /* we use SSE2 routines to perform the integer ops */ - COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2); - imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1); - - COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2); - y = _mm256_cvtepi32_ps(imm2); - - imm4_1 = imm2_1; - imm4_2 = imm2_2; - - imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4); - imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4); - - imm0_1 = _mm_slli_epi32(imm0_1, 29); - imm0_2 = _mm_slli_epi32(imm0_2, 29); - - COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2); - - imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); - imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); - - COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); -#endif v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0); v8sf poly_mask = _mm256_castsi256_ps(imm2); @@ -653,22 +471,9 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { x = _mm256_add_ps(x, xmm2); x = _mm256_add_ps(x, xmm3); -#ifdef CPU_CAPABILITY_AVX2 imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2); imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4); imm4 = _mm256_slli_epi32(imm4, 29); -#else - imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2); - imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2); - - imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4); - imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4); - - imm4_1 = _mm_slli_epi32(imm4_1, 29); - imm4_2 = _mm_slli_epi32(imm4_2, 29); - - COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4); -#endif v8sf sign_bit_cos = _mm256_castsi256_ps(imm4); @@ -713,3 +518,5 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { *s = _mm256_xor_ps(xmm1, sign_bit_sin); *c = _mm256_xor_ps(xmm2, sign_bit_cos); } + +#endif // CPU_CAPABILITY_AVX2 diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 7c3635ed54f..f146ffdc2ca 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -149,8 +149,8 @@ static void _fft_fill_with_conjugate_symmetry_cpu_( // Register this one implementation for all cpu types instead of compiling multiple times REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_) -REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) // _out variants can be shared between PocketFFT and MKL Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 045bf828472..b7c63a247ea 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -197,7 +198,28 @@ int64_t hsum(const uint8_t* A, int len) { for (const auto k : c10::irange(8)) { row_sum += temp[k]; } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_v = _mm512_setzero_si512(); + __m512i one_epi16_v = _mm512_set1_epi16(1); + __m512i one_epi8_v = _mm512_set1_epi8(1); + // vectorized + for (; i < len / 64 * 64; i += 64) { + __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + sum_v = _mm512_add_epi32( + sum_v, + _mm512_madd_epi16( + // first argument is unsigned, second is signed + _mm512_maddubs_epi16(src_v, one_epi8_v), + one_epi16_v) + ); + } + + alignas(64) int32_t temp[16]; + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v); + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -233,7 +255,28 @@ int64_t hsum(const int8_t* A, int len) { for (const auto k : c10::irange(8)) { row_sum += temp[k]; } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_v = _mm512_setzero_si512(); + __m512i one_epi16_v = _mm512_set1_epi16(1); + __m512i one_epi8_v = _mm512_set1_epi8(1); + // vectorized + for (; i < len / 64 * 64; i += 64) { + __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + sum_v = _mm512_add_epi32( + sum_v, + _mm512_madd_epi16( + // first argument is unsigned, second is signed + _mm512_maddubs_epi16(one_epi8_v, src_v), + one_epi16_v) + ); + } + + alignas(64) int32_t temp[16]; + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v); + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -255,7 +298,7 @@ int64_t hsum(const int32_t* A, int len) { __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); // widen __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32); - __m128i src_hi_epi32 = _mm256_extractf128_si256(src_epi32, 1); + __m128i src_hi_epi32 = _mm256_extracti128_si256(src_epi32, 1); __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32); __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32); // add @@ -268,7 +311,27 @@ int64_t hsum(const int32_t* A, int len) { for (const auto k : c10::irange(4)) { row_sum += temp[k]; } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_epi64 = _mm512_setzero_si512(); + // vectorized + for (; i < len / 16 * 16; i += 16) { + __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + // widen + __m256i src_lo_epi32 = _mm512_castsi512_si256(src_epi32); + __m256i src_hi_epi32 = _mm512_extracti32x8_epi32(src_epi32, 1); + __m512i src_lo_epi64 = _mm512_cvtepi32_epi64(src_lo_epi32); + __m512i src_hi_epi64 = _mm512_cvtepi32_epi64(src_hi_epi32); + // add + sum_epi64 = _mm512_add_epi64(sum_epi64, src_lo_epi64); + sum_epi64 = _mm512_add_epi64(sum_epi64, src_hi_epi64); + } + + alignas(64) int64_t temp[8]; + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_epi64); + for (const auto k : c10::irange(8)) { + row_sum += temp[k]; + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -313,7 +376,36 @@ int64_t hsum_sq(const uint8_t* A, int len) { } sum_v_epu32 = _mm256_setzero_si256(); } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_v_epu32 = _mm512_setzero_si512(); + alignas(64) int32_t temp[16]; + int overflow_threshold = 262144; // 2147483647(max of int32)/(512*512)*8 = 262144 + int loop = len / overflow_threshold + 1; + for(int j=0; j<=loop; j++){ + for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) { + // (i31, ..., i0) + __m256i src_epu8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); + __m512i src_epu16 = _mm512_cvtepu8_epi16(src_epu8); + // (i31 ^ 2, ..., i0 ^ 2) + __m512i sq_epu16 = _mm512_mullo_epi16(src_epu16, src_epu16); + // (i15 ^ 2, ..., i0 ^ 2) + __m256i sq_lo_epu16 = _mm512_castsi512_si256(sq_epu16); + // (i31 ^ 2, ..., i16 ^ 2) + __m256i sq_hi_epu16 = _mm512_extracti32x8_epi32(sq_epu16, 1); + // widen to epu32 + __m512i sq_lo_epu32 = _mm512_cvtepu16_epi32(sq_lo_epu16); + __m512i sq_hi_epu32 = _mm512_cvtepu16_epi32(sq_hi_epu16); + // add to running sum + sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_lo_epu32); + sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_hi_epu32); + } + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epu32); + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } + sum_v_epu32 = _mm512_setzero_si512(); + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -361,7 +453,40 @@ int64_t hsum_sq(const int8_t* A, int len) { } sum_v_epi32 = _mm256_setzero_si256(); } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + // vectorized + __m512i sum_v_epi32 = _mm512_setzero_si512(); + alignas(64) int32_t temp[16]; + + int overflow_threshold = 1048576; //2147483647/(256*256)*8 = 1048576 + int loop = len / overflow_threshold + 1; + + for(int j=0; j<=loop; j++){ + for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) { + // (i31, ..., i0) + __m256i src_epi8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); + __m512i src_epi16 = _mm512_cvtepi8_epi16(src_epi8); + // (i31 ^ 2, ..., i0 ^ 2) + __m512i sq_epi16 = _mm512_mullo_epi16(src_epi16, src_epi16); + // (i15 ^ 2, ..., i0 ^ 2) + __m256i sq_lo_epi16 = _mm512_castsi512_si256(sq_epi16); + // (i31 ^ 2, ..., i16 ^ 2) + __m256i sq_hi_epi16 = _mm512_extracti32x8_epi32(sq_epi16, 1); + // widen to epi32 + __m512i sq_lo_epi32 = _mm512_cvtepi16_epi32(sq_lo_epi16); + __m512i sq_hi_epi32 = _mm512_cvtepi16_epi32(sq_hi_epi16); + // add to running sum + sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_lo_epi32); + sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_hi_epi32); + } + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epi32); + + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } + sum_v_epi32 = _mm512_setzero_si512(); + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -391,7 +516,21 @@ float hsum_sq(const int32_t* A, int len) { for (const auto k : c10::irange(8)) { row_sum += static_cast(temp[k]); } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512 sum_ps = _mm512_setzero_ps(); + // vectorized + for (; i < len / 16 * 16; i += 16) { + __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + __m512 src_ps = _mm512_cvtepi32_ps(src_epi32); + sum_ps = _mm512_add_ps(sum_ps, _mm512_mul_ps(src_ps, src_ps)); + } + + alignas(64) float temp[16]; + _mm512_store_ps(temp, sum_ps); + for (const auto k : c10::irange(16)) { + row_sum += static_cast(temp[k]); + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -1239,7 +1378,7 @@ void qmaxpool_2d_nhwc_kernel( } template -void do_avg_pool_nhwc_on_AVX2( +void do_avg_pool_nhwc_on_AVX_n( const typename T::underlying* i_p, typename T::underlying* o_p, int& c_start, @@ -1256,17 +1395,25 @@ void do_avg_pool_nhwc_on_AVX2( int hsize, int wsize, int csize) { -#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) // buffer for channel accumulator, used to interchange channel-loop // to inner-most, so that memory access of the input tensor data is // continuous. +#ifdef CPU_CAPABILITY_AVX2 constexpr int cb_size = 16; +#else + constexpr int cb_size = 8; +#endif constexpr int vec_width = Vectorized::size() / 4; constexpr int cb_step = cb_size * vec_width; Vectorized acc_buffer[cb_size]; Vectorized acc_buffer_fp[cb_size]; +#ifdef CPU_CAPABILITY_AVX2 if (vec_width == 8) { +#else + if (vec_width == 16) { +#endif for (int c = c_start; c < csize; c += cb_step) { int cend = std::min(cb_size, (csize - c) / vec_width); // initialize loop @@ -1292,14 +1439,23 @@ void do_avg_pool_nhwc_on_AVX2( // convert int32 accumulative to fp32 vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width); - // first quantize using AVX using 32 lanes, then 8, finally falls + // first quantize using AVX2 or AVX512 using 32 lanes, then 8, finally falls // back to single +#ifdef CPU_CAPABILITY_AVX2 QuantizeAvx2( (float*)acc_buffer_fp, o_p + c, cend * vec_width, multiplier, output_zero_point); +#else + QuantizeAvx512( + (float*)acc_buffer_fp, + o_p + c, + cend * vec_width, + multiplier, + output_zero_point); +#endif } c_start = csize / vec_width * vec_width; } @@ -1307,7 +1463,7 @@ void do_avg_pool_nhwc_on_AVX2( } template -void do_avg_pool_on_AVX2( +void do_avg_pool_on_AVX_n( typename T::underlying* i_p, typename T::underlying* o_p, int64_t& c, @@ -1326,9 +1482,13 @@ void do_avg_pool_on_AVX2( int64_t stride_D, int64_t stride_H, int64_t stride_W) { -#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) - constexpr auto vec_width = Vectorized::size() / 4; +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) + constexpr int vec_width = Vectorized::size() / 4; +#ifdef CPU_CAPABILITY_AVX2 if (vec_width == 8) { +#else + if (vec_width == 16) { +#endif for (; c + vec_width <= channel_size; c += vec_width) { int64_t tcntr = 0; @@ -1416,10 +1576,10 @@ void _qadaptive_avg_pool_kernel( istartH * istrideH + istartW * istrideW; - // Note: If AVX is not available, `do_avg_pool_on_AVX2 is a noop. + // Note: If AVX is not available, `do_avg_pool_on_AVX_n is a noop. // In that case, the following loop takes over // TODO: more vectorization with loop interleaving - do_avg_pool_on_AVX2( + do_avg_pool_on_AVX_n( internal_i_p, o_p, c, @@ -1438,7 +1598,6 @@ void _qadaptive_avg_pool_kernel( istrideD, istrideH, istrideW); - // 1) The following loop handles the remaining channels // 2) It also handles the Non-AVX2 path for (; c < sizeC; ++c) { @@ -1610,7 +1769,7 @@ void _qavg_pool_nhwc_kernel( // For int8 quantization, we implicitly use int32 as accumulation // Or else, it will go to the slow path // TODO: support 16bit, 32bit, and etc. - do_avg_pool_nhwc_on_AVX2( + do_avg_pool_nhwc_on_AVX_n( i_p, o_p, c_start, @@ -1744,7 +1903,7 @@ void qavg_pool3d_nhwc_kernel( } template -int64_t do_quantized_bilinear_on_AVX2( +int64_t do_quantized_bilinear_on_AVX_n( const typename T::underlying*& pos1, typename T::underlying*& pos2, int64_t input_height, @@ -1762,9 +1921,13 @@ int64_t do_quantized_bilinear_on_AVX2( const int64_t h1p, const int64_t w1p) { int64_t c = 0; -#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) constexpr auto vec_width = Vectorized::size() / 4; +#ifdef CPU_CAPABILITY_AVX2 if (vec_width == 8) { +#else + if (vec_width == 16) { +#endif for (; c + vec_width <= channels; c += vec_width) { Vectorized pos1_fp_v[4]; Vectorized pos1_int_v[4]; @@ -1861,7 +2024,7 @@ void qupsample_bilinear2d_nhwc_kernel( o_p + (h2 * output_width + w2) * channels; // We have to isolate this function out because the VS does not // expand the macro correctly. - c = do_quantized_bilinear_on_AVX2( + c = do_quantized_bilinear_on_AVX_n( pos1, pos2, input_height, @@ -1989,7 +2152,7 @@ void q_batch_norm_kernel( reinterpret_cast(input.data_ptr()); scalar_t::underlying* Y = reinterpret_cast(output.data_ptr()); - constexpr int kVLen = 8; + constexpr int kVLen = Vectorized::size(); const int64_t outer_size = N * HxW; using Vec = Vectorized; // Hoisted variables @@ -2292,7 +2455,7 @@ void quantized_normalize_kernel( float y_scale = Y->q_scale(); float y_inv_scale = 1.0f / y_scale; - constexpr int kFloatVLen = 8; + constexpr int kFloatVLen = fVec::size(); int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs(); int64_t kNumIntVecInLayer = N / kIntVLen; int64_t kNonVecRemInLayer = N % kIntVLen; @@ -3095,6 +3258,114 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu( } // namespace +// Some quantization tests are flaky on Windows with AVX512. If --continue-through-error +// is used, only one fails. But if the failing test is skipped, another one fails. +// If the second test is also skipped, a third one fails. +// So, until Quantization support for Windows is fixed for AVX512, +// AVX2 kernels would be used instead. Ref: GH 56992. +#if defined(CPU_CAPABILITY_AVX512) && defined(_WIN32) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub, + dequantize_tensor_per_channel_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_stub, + dequantize_tensor_per_tensor_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub, + dequantize_tensor_per_channel_float_qparams_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_tensor_stub, + fake_quant_learnable_grad_tensor_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub, + fake_quant_per_channel_cachemask_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_stub, + fake_quant_tensor_cachemask_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub, + fake_quant_tensor_cachemask_tensor_qparams_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool2d_nhwc_stub, + qadaptive_avg_pool2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub, + qadaptive_avg_pool3d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_relu_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_scalar_relu_stub, qadd_scalar_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_scalar_stub, qadd_scalar_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, qavg_pool2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, qavg_pool3d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qbatch_norm_relu_stub, qbatch_norm_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qbatch_norm_stub, qbatch_norm_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qcat_nhwc_stub, qcat_nhwc_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qcat_relu_nhwc_stub, qcat_nhwc_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qclamp_stub, qclamp_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qclamp_min_stub, qclamp_minmax_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qclamp_max_stub, qclamp_minmax_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qelu_stub, qelu_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qhardsigmoid_stub, qhardsigmoid_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qhardswish_stub, qhardswish_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qmaxpool_2d_nhwc_stub, qmaxpool_2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qmul_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qrelu6_stub, qrelu_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub, qrelu_leaky_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qrelu_stub, qrelu_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub, qsigmoid_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qtanh_stub, qtanh_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qthreshold_stub, qthreshold_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qtopk_stub, qtopk_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_channel_stub, + fake_quant_learnable_per_channel_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_stub, + quantize_tensor_per_tensor_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_affine_stub, + quantize_tensor_per_channel_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_float_qparams_stub, + quantize_tensor_per_channel_float_qparams_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantized_normalize_stub, qnormalize_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qupsample_bilinear2d_nhwc_stub, qupsample_bilinear2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub, + quantize_tensor_per_tensor_affine_sub_byte_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub, + dequantize_tensor_per_tensor_affine_sub_byte_fn); +#else // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub, &dequantize_tensor_per_channel_affine_cpu); @@ -3174,7 +3445,8 @@ REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &fake_quantize_learnable_channel_grad_kernel_cpu); +REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, + &fake_quantize_learnable_channel_grad_kernel_cpu); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH( quantize_tensor_per_tensor_affine_stub, @@ -3200,7 +3472,7 @@ REGISTER_DISPATCH( REGISTER_DISPATCH( dequantize_tensor_per_tensor_affine_sub_byte_stub, &dequantize_tensor_per_tensor_affine_sub_byte_cpu); - +#endif // CPU_CAPABILITY_AVX512 && _WIN32 } // namespace native } // namespace at diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 77bc02334db..4ee0596da6e 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -1071,13 +1071,17 @@ namespace { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(ComplexTests, TestComplexFloatImagRealConj) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28 }; + float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28, + 9.5488e-28,10.5488e-28,11.5488e-28,12.5488e-28,13.5488e-28,14.5488e-28,15.5488e-28,16.5488e-28}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0 }; + float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0,aa[8],0,aa[10],0,aa[12],0,aa[14],0 }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0 }; + float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0,aa[9],0,aa[11],0,aa[13],0,aa[15],0 }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28 }; + float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28, + 5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28, + 9.5488e-28,-10.5488e-28,11.5488e-28,-12.5488e-28, + 13.5488e-28,-14.5488e-28,15.5488e-28,-16.5488e-28 }; auto a = vcomplex::loadu(aa); auto actual1 = a.real(); auto actual3 = a.imag(); @@ -1304,6 +1308,7 @@ namespace { }, test_case); } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TYPED_TEST(FunctionalTests, Map) { using vec = TypeParam; using VT = ValueType; @@ -1339,15 +1344,16 @@ namespace { at::vec::map3([](vec x1, vec x2, vec x3) { return x1 + x2 + x3; }, y, x1, x2, x3, N); for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i]; } cmp(y, ref_y); - // test map3: y = x1 + x2 + x3 + x4 + // test map4: y = x1 + x2 + x3 + x4 at::vec::map4([](vec x1, vec x2, vec x3, vec x4) { return x1 + x2 + x3 + x4; }, y, x1, x2, x3, x4, N); for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i] + x4[i]; } cmp(y, ref_y); } - TYPED_TEST(FunctionalBF16Tests, Reduce) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + TYPED_TEST(FunctionalBF16Tests, Reduce) { using vec = TypeParam; // Can't use ValueType here: - // Vectorized::value_type returns uint16_t on AVX2 + // Vectorized::value_type returns uint16_t on AVX2/AVX512 using VT = c10::BFloat16; using RT = float; // reference constexpr auto R = 2LL; // residual @@ -1394,7 +1400,6 @@ namespace { auto y2 = at::vec::map_reduce_all([](auto x) { return x - x.exp(); }, sum, x_b1, len); ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed << "\nmap_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2); - } // Map2ReduceAll for (int64_t len = 1; len <= N; len++) { diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 0c9496e166c..8b0854866a9 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -13,7 +13,13 @@ #include #include #include + +#if defined(CPU_CAPABILITY_AVX512) +#define CACHE_LINE 64 +#else #define CACHE_LINE 32 +#endif + #if defined(__GNUC__) #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE))) #define not_inline __attribute__((noinline)) @@ -26,7 +32,7 @@ CACHE_ALIGN #define #endif #if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) #define TEST_AGAINST_DEFAULT 1 -#elif !defined(CPU_CAPABILITY_AVX) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX) +#elif !defined(CPU_CAPABILITY_AVX512) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX) #define TEST_AGAINST_DEFAULT 1 #else #undef TEST_AGAINST_DEFAULT @@ -41,7 +47,8 @@ CACHE_ALIGN #define return __VA_ARGS__(std::forward(args)...); \ } -#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) && (defined(__GNUC__) || defined(__GNUG__)) +#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) || \ + defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__)) #undef CHECK_DEQUANT_WITH_LOW_PRECISION #define CHECK_WITH_FMA 1 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 17b1a688c5e..429821496b3 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -722,44 +722,43 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS}) endif() -# NOTE [ Linking AVX and non-AVX files ] +# NOTE [ Linking AVX-n and non-AVX-n files ] # -# Regardless of the CPU capabilities, we build some files with AVX and AVX2 +# Regardless of the CPU capabilities, we build some files with AVX2, and AVX512 # instruction set. If the host CPU doesn't support those, we simply ignore their # functions at runtime during dispatch. # # We must make sure that those files are at the end of the input list when # linking the torch_cpu library. Otherwise, the following error scenario might # occur: -# 1. A non-AVX and an AVX file both call a function defined with the `inline` +# 1. A non-AVX2 and an AVX2 file both call a function defined with the `inline` # keyword # 2. The compiler decides not to inline this function # 3. Two different versions of the machine code are generated for this function: -# one without AVX instructions and one with AVX. -# 4. When linking, the AVX version is found earlier in the input object files, +# one without AVX2 instructions and one with AVX2. +# 4. When linking, the AVX2 version is found earlier in the input object files, # so the linker makes the entire library use it, even in code not guarded by # the dispatcher. -# 5. A CPU without AVX support executes this function, encounters an AVX +# 5. A CPU without AVX2 support executes this function, encounters an AVX2 # instruction and crashes. # # Thus we organize the input files in the following order: -# 1. All files with no AVX support -# 2. All files with AVX support (conveniently, they all have names ending with -# 'AVX.cpp') -# 3. All files with AVX2 support ('*AVX2.cpp') +# 1. All files with no AVX-n support +# 2. All files with AVX2 support ('*AVX2.cpp') +# 3. All files with AVX512 support ('*AVX512.cpp') set(Caffe2_CPU_SRCS_NON_AVX) -set(Caffe2_CPU_SRCS_AVX) set(Caffe2_CPU_SRCS_AVX2) +set(Caffe2_CPU_SRCS_AVX512) foreach(input_filename ${Caffe2_CPU_SRCS}) - if(${input_filename} MATCHES "AVX\\.cpp") - list(APPEND Caffe2_CPU_SRCS_AVX ${input_filename}) - elseif(${input_filename} MATCHES "AVX2\\.cpp") + if(${input_filename} MATCHES "AVX2\\.cpp") list(APPEND Caffe2_CPU_SRCS_AVX2 ${input_filename}) + elseif(${input_filename} MATCHES "AVX512\\.cpp") + list(APPEND Caffe2_CPU_SRCS_AVX512 ${input_filename}) else() list(APPEND Caffe2_CPU_SRCS_NON_AVX ${input_filename}) endif() endforeach(input_filename) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX} ${Caffe2_CPU_SRCS_AVX2}) +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX2} ${Caffe2_CPU_SRCS_AVX512}) # ========================================================== # END formerly-libtorch sources diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 19579b9a32b..aeeaf64193c 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -63,14 +63,6 @@ if(INTERN_BUILD_ATEN_OPS) endif() endif(MSVC) - if(C_AVX_FOUND) - if(MSVC) - set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG}/arch:AVX ${CXX_AVX_FLAGS}") - else(MSVC) - set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG} ${CXX_AVX_FLAGS}") - endif(MSVC) - endif(C_AVX_FOUND) - if(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang") set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/MapAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp") endif() @@ -80,15 +72,16 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND CPU_CAPABILITY_NAMES "DEFAULT") list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}") - if(CXX_AVX_FOUND) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX_CPU_DEFINITION") - list(APPEND CPU_CAPABILITY_NAMES "AVX") + + if(CXX_AVX512_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_CPU_DEFINITION") + list(APPEND CPU_CAPABILITY_NAMES "AVX512") if(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512") else(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma") endif(MSVC) - endif(CXX_AVX_FOUND) + endif(CXX_AVX512_FOUND) if(CXX_AVX2_FOUND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX2_CPU_DEFINITION") @@ -103,11 +96,24 @@ if(INTERN_BUILD_ATEN_OPS) endif(COMPILER_SUPPORTS_NO_AVX256_SPLIT) list(APPEND CPU_CAPABILITY_NAMES "AVX2") - if(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2") - else(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}") - endif(MSVC) + if(DEFINED ENV{ATEN_AVX512_256}) + if($ENV{ATEN_AVX512_256} MATCHES "TRUE") + if(CXX_AVX512_FOUND) + message("-- ATen AVX2 kernels will use 32 ymm registers") + if(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512") + else(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=native ${CPU_NO_AVX256_SPLIT_FLAGS}") + endif(MSVC) + endif(CXX_AVX512_FOUND) + endif() + else() + if(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2") + else(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}") + endif(MSVC) + endif() endif(CXX_AVX2_FOUND) if(CXX_VSX_FOUND) diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake index 7d472eb662c..c04427cbad8 100644 --- a/cmake/Modules/FindAVX.cmake +++ b/cmake/Modules/FindAVX.cmake @@ -12,6 +12,25 @@ SET(AVX_CODE " } ") +SET(AVX512_CODE " + #include + + int main() + { + __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0); + __m512i b = a; + __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ); + return 0; + } +") + SET(AVX2_CODE " #include @@ -56,6 +75,8 @@ ENDMACRO() CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX") CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2") +CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512") CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX") CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2") +CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512") diff --git a/setup.py b/setup.py index 5264fe7fa46..73d7d11adae 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,12 @@ # BUILD_BINARY # enables the additional binaries/ build # +# ATEN_AVX512_256=TRUE +# ATen AVX2 kernels can use 32 ymm registers, instead of the default 16. +# This option can be used if AVX512 doesn't perform well on a machine. +# The FBGEMM library also uses AVX512_256 kernels on Xeon D processors, +# but it also has some (optimized) assembly code. +# # PYTORCH_BUILD_VERSION # PYTORCH_BUILD_NUMBER # specify the version of PyTorch, rather than the hard-coded version @@ -928,6 +934,7 @@ if __name__ == '__main__': 'include/ATen/*.h', 'include/ATen/cpu/*.h', 'include/ATen/cpu/vec/vec256/*.h', + 'include/ATen/cpu/vec/vec512/*.h', 'include/ATen/cpu/vec/*.h', 'include/ATen/core/*.h', 'include/ATen/cuda/*.cuh', diff --git a/test/cpp/api/dispatch.cpp b/test/cpp/api/dispatch.cpp index e5bc35177dd..6416fe3e809 100644 --- a/test/cpp/api/dispatch.cpp +++ b/test/cpp/api/dispatch.cpp @@ -29,19 +29,19 @@ TEST_F(DispatchTest, TestAVX2) { } } -TEST_F(DispatchTest, TestAVX) { +TEST_F(DispatchTest, TestAVX512) { const std::vector ints {1, 2, 3, 4}; const std::vector result {1, 4, 27, 256}; const auto vals_tensor = torch::tensor(ints); const auto pows_tensor = torch::tensor(ints); #ifdef _WIN32 - _putenv("ATEN_CPU_CAPABILITY=avx"); + _putenv("ATEN_CPU_CAPABILITY=avx512"); #else - setenv("ATEN_CPU_CAPABILITY", "avx", 1); + setenv("ATEN_CPU_CAPABILITY", "avx512", 1); #endif - const auto actual_pow_avx = vals_tensor.pow(pows_tensor); + const auto actual_pow_avx512 = vals_tensor.pow(pows_tensor); for (int i = 0; i < 4; i++) { - ASSERT_EQ(result[i], actual_pow_avx[i].item()); + ASSERT_EQ(result[i], actual_pow_avx512[i].item()); } } diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py index f9c24c78980..65253869ddc 100644 --- a/test/quantization/bc/test_backward_compatibility.py +++ b/test/quantization/bc/test_backward_compatibility.py @@ -2,7 +2,7 @@ import sys import os - +import unittest # torch import torch import torch.nn as nn @@ -11,7 +11,7 @@ import torch.nn.quantized.dynamic as nnqd import torch.nn.intrinsic.quantized as nniq # Testing utils -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm def remove_prefix(text, prefix): @@ -238,6 +238,7 @@ class TestSerialization(TestCase): # TODO: graph mode quantized conv3d module @override_qengines + @unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098") def test_lstm(self): class LSTMModule(torch.nn.Module): def __init__(self): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 94f72b1ee41..fd30756dca5 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -339,6 +339,15 @@ IS_WINDOWS = sys.platform == "win32" IS_MACOS = sys.platform == "darwin" IS_PPC = platform.machine() == "ppc64le" +def is_avx512_vnni_supported(): + if sys.platform != 'linux': + return False + with open("/proc/cpuinfo", encoding="ascii") as f: + lines = f.read() + return "avx512vnni" in lines + +IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported() + if IS_WINDOWS: @contextmanager def TemporaryFileName(*args, **kwargs):