Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add AVX512 support in ATen & remove AVX support (#61903)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61903

### Remaining Tasks

- [ ] Collate results of benchmarks on two Intel Xeon machines (with & without CUDA, to check if CPU throttling causes issues with GPUs) and make graphs, including Roofline model plots (Intel Advisor can't make them with libgomp, though it can with Intel OpenMP).

### Summary

1. This draft PR produces binaries with 3 types of ATen kernels: default, AVX2, and AVX512. Using the environment variable `ATEN_AVX512_256=TRUE` also results in 3 types of kernels, but the compiler can use 32 ymm registers for AVX2, instead of the default 16. ATen kernels for `CPU_CAPABILITY_AVX` have been removed.
2. `nansum` is not using an AVX512 kernel right now, as it has poorer accuracy for Float16 than AVX2 or DEFAULT, whose respective accuracies aren't very good either (#59415). It was more convenient to disable AVX512 dispatch for all dtypes of `nansum` for now.
3. On Windows, ATen Quantized AVX512 kernels are not being used, as quantization tests are flaky. If `--continue-through-failure` is used, then `test_compare_model_outputs_functional_static` fails. But if this test is skipped, `test_compare_model_outputs_conv_static` fails. If both these tests are skipped, then a third one fails. These are hard to debug right now due to not having access to a Windows machine with AVX512 support, so it was more convenient to disable AVX512 dispatch of all ATen Quantized kernels on Windows for now.
4. One test is currently being skipped, [`test_lstm` in `quantization.bc`](https://github.com/pytorch/pytorch/issues/59098). It fails only on Cascade Lake machines, irrespective of the `ATEN_CPU_CAPABILITY` used, because FBGEMM uses `AVX512_VNNI` on machines that support it. The value of `reduce_range` should be `False` on such machines.

The list of the changes is at https://gist.github.com/imaginary-person/4b4fda660534f0493bf9573d511a878d.

Credits to ezyang for proposing `AVX512_256` - these kernels use AVX2 intrinsics but benefit from 32 registers, instead of the 16 ymm registers that AVX2 uses by default.
Credits to limo1996 for the initial proposal, and for optimizing `hsub_pd` & `hadd_pd`, which didn't have direct AVX512 equivalents and are being used in some kernels. He also refactored `vec/functional.h` to remove duplicated code.
Credits to quickwritereader for helping fix 4 failing complex multiplication & division tests.

### Testing

1. `vec_test_all_types` was modified to test basic AVX512 support, as tests already existed for AVX2. Only one test had to be modified, as it was hardcoded for AVX2.
2. `pytorch_linux_bionic_py3_8_gcc9_coverage_test1` & `pytorch_linux_bionic_py3_8_gcc9_coverage_test2` are now using `linux.2xlarge` instances, as they support AVX512. They were used for testing AVX512 kernels, since AVX512 kernels are being used by default in both of those CI checks. Windows CI checks had already been using machines with AVX512 support.

### Would the downclocking caused by AVX512 pose an issue?

I think it's important to note that AVX2 causes downclocking as well, and the additional downclocking caused by AVX512 may not hamper performance on some Skylake machines & beyond, because of the doubled vector size. I think that [this post with verifiable references is a must-read](https://community.intel.com/t5/Software-Tuning-Performance/Unexpected-power-vs-cores-profile-for-MKL-kernels-on-modern-Xeon/m-p/1133869/highlight/true#M6450).

Also, AVX512 would _probably not_ hurt performance on a high-end machine, [but measurements are recommended](https://lemire.me/blog/2018/09/07/avx-512-when-and-how-to-use-these-new-instructions/). In case it does, `ATEN_AVX512_256=TRUE` can be used for building PyTorch, as AVX2 can then use 32 ymm registers instead of the default 16. [FBGEMM uses `AVX512_256` only on Xeon D processors](https://github.com/pytorch/FBGEMM/pull/209), which are said to have poor AVX512 performance.

This [official data](https://www.intel.com/content/dam/www/public/us/en/documents/specification-updates/xeon-scalable-spec-update.pdf) is for the Intel Skylake family, and the first link helps understand its significance. Cascade Lake & Ice Lake SP Xeon processors are said to be even better when it comes to AVX512 performance; here is the corresponding data for [Cascade Lake](https://cdrdv2.intel.com/v1/dl/getContent/338848). The corresponding data isn't publicly available for Intel Xeon SP 3rd gen (Ice Lake SP), but [Intel mentioned that the 3rd gen has frequency improvements pertaining to AVX512](https://newsroom.intel.com/wp-content/uploads/sites/11/2021/04/3rd-Gen-Intel-Xeon-Scalable-Platform-Press-Presentation-281884.pdf). Ice Lake SP machines also have 48 KB L1D caches, so that's another reason for AVX512 performance to be better on them.

### Is PyTorch always faster with AVX512?

No, but then PyTorch is not always faster with AVX2 either. Please refer to #60202. The benefit from vectorization is apparent with small tensors that fit in caches or in kernels that are more compute-heavy. For instance, AVX512 or AVX2 would yield no benefit for adding two 64 MB tensors, but adding two 1 MB tensors would do well with AVX2, and even more so with AVX512. Memory-bound computations, such as adding two 64 MB tensors, can be slow with vectorization (depending upon the number of threads used), as the effects of downclocking can then be observed.

Original pull request: https://github.com/pytorch/pytorch/pull/56992

Reviewed By: soulitzer

Differential Revision: D29266289

Pulled By: ezyang

fbshipit-source-id: 2d5e8d1c2307252f22423bbc14f136c67c3e6184
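For quick reference, the two environment variables described above can be exercised as sketched below. This is a minimal sketch, assuming a source build of this branch; it also assumes that `torch.__config__.show()` surfaces the capability string built by `used_cpu_capability()` (one of the hunks below), so the exact output may differ.

```bash
# Build-time knob (sketch): have the third kernel flavour use AVX2 intrinsics
# with 32 ymm registers (AVX512_256) instead of full 512-bit AVX512 kernels.
ATEN_AVX512_256=TRUE python setup.py develop

# Run-time knob: cap kernel dispatch without rebuilding, then print which
# CPU capability ATen actually reports.
ATEN_CPU_CAPABILITY=avx2    python -c "import torch; print(torch.__config__.show())"
ATEN_CPU_CAPABILITY=default python -c "import torch; print(torch.__config__.show())"
```

The CI hunk right after the commit metadata pins the `nogpu_NO_AVX*` test configurations with the same `ATEN_CPU_CAPABILITY` values (`default`, `avx2`).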
This commit is contained in:
parent 59d6e07ada
commit 9e53c823b8
```diff
@@ -132,7 +132,9 @@ fi
 if [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX-* || $TEST_CONFIG == 'nogpu_NO_AVX' ]]; then
   export ATEN_CPU_CAPABILITY=default
 elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* || $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then
-  export ATEN_CPU_CAPABILITY=avx
+  export ATEN_CPU_CAPABILITY=default
+elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX512-* || $TEST_CONFIG == 'nogpu_NO_AVX512' ]]; then
+  export ATEN_CPU_CAPABILITY=avx2
 fi

 if [ -n "$IN_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then
```
aten.bzl

```diff
@@ -1,9 +1,8 @@
 load("@rules_cc//cc:defs.bzl", "cc_library")

-CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX", "AVX2"]
+CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]
 CAPABILITY_COMPILER_FLAGS = {
     "AVX2": ["-mavx2", "-mfma"],
-    "AVX": ["-mavx"],
     "DEFAULT": [],
 }

```
```diff
@@ -50,7 +50,7 @@ if(NOT BUILD_LITE_INTERPRETER)
 endif()
 EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})

-file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h")
+file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h")
 file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp")
 file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
 file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")
```
```diff
@@ -108,12 +108,12 @@ std::string used_cpu_capability() {
     case native::CPUCapability::DEFAULT:
       ss << "NO AVX";
       break;
-    case native::CPUCapability::AVX:
-      ss << "AVX";
-      break;
     case native::CPUCapability::AVX2:
       ss << "AVX2";
       break;
+    case native::CPUCapability::AVX512:
+      ss << "AVX512";
+      break;
 #endif
     default:
       break;
```
```diff
@@ -1,6 +1,5 @@
 #include <ATen/cpu/FlushDenormal.h>
-#include <ATen/cpu/vec/vec256/intrinsics.h>
+#include <ATen/cpu/vec/intrinsics.h>
 #include <cpuinfo.h>

 namespace at { namespace cpu {
```
```diff
@@ -1 +1,6 @@
-#include <ATen/cpu/vec/vec256/functional.h>
+#pragma once
+
+#include <ATen/cpu/vec/functional_base.h>
+#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
+#include <ATen/cpu/vec/functional_bfloat16.h>
+#endif
```
```diff
@@ -3,7 +3,7 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/vec256.h>
+#include <ATen/cpu/vec/vec.h>

 namespace at { namespace vec {

@@ -3,7 +3,7 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/functional_base.h>
+#include <ATen/cpu/vec/vec.h>

 namespace at { namespace vec {

```
```diff
@@ -15,26 +15,26 @@ template <> struct VecScalarType<BFloat16> { using type = float; };
 template <typename scalar_t>
 using vec_scalar_t = typename VecScalarType<scalar_t>::type;

-// Note that we already have specializes member of Vectorized<scalar_t> for BFloat16
-// so the following function would run smoothly:
+// Note that we already have specialized member of Vectorized<scalar_t> for BFloat16
+// so the following functions would run smoothly:
 //   using Vec = Vectorized<BFloat16>;
 //   Vec one = Vec(BFloat16(1));
 //   vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
 //
-// Why we still need to specializes "funtional"?
+// Then why we still need to specialize "funtional"?
 // If we do specialization at Vectorized<> level, the above example would need 3 pairs of
 // conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/".
 // If we do specialization at vec::map<>() level, we have only 1 pair of conversion
 // of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only.
 //
-// The following BFloat16 functionalities will only do data type conversion for input
-// and output vector (reduce functionalities will only convert the final scalar back to bf16).
+// The following BFloat16 functionality will only do data type conversion for input
+// and output vector (reduce functionality will only convert the final scalar back to bf16).
 // Compared to Vectorized<> specialization,
 // 1. better performance since we have less data type conversion;
 // 2. less rounding error since immediate results are kept in fp32;
 // 3. accumulation done on data type of fp32.
 //
-// If you plan to extend this file, make sure add unit test at
+// If you plan to extend this file, please ensure adding unit tests at
 //   aten/src/ATen/test/vec_test_all_types.cpp
 //
 template <typename scalar_t = BFloat16, typename Op>
```
```diff
@@ -1,6 +1,6 @@
 #pragma once
-#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
-/* Clang-compatible compiler, targeting x86/x86-64 */
+#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
+/* GCC or clang-compatible compiler, targeting x86/x86-64 */
 #include <x86intrin.h>
 #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
 /* Clang-compatible compiler, targeting arm neon */
@@ -14,9 +14,6 @@
 #define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
 #define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
 #endif
-#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
-/* GCC-compatible compiler, targeting x86/x86-64 */
-#include <x86intrin.h>
 #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
 /* GCC-compatible compiler, targeting ARM with NEON */
 #include <arm_neon.h>
```
```diff
@@ -1 +1,5 @@
+#if defined(CPU_CAPABILITY_AVX512)
+#include <ATen/cpu/vec/vec512/vec512.h>
+#else
 #include <ATen/cpu/vec/vec256/vec256.h>
+#endif
```
```diff
@@ -1,6 +0,0 @@
-#pragma once
-
-#include <ATen/cpu/vec/vec256/functional_base.h>
-#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
-#include <ATen/cpu/vec/vec256/functional_bfloat16.h>
-#endif
```
```diff
@@ -3,9 +3,9 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
+#include <ATen/cpu/vec/intrinsics.h>

-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/vec_base.h>
 #if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
 #include <ATen/cpu/vec/vec256/vec256_float.h>
 #include <ATen/cpu/vec/vec256/vec256_float_neon.h>
@@ -68,9 +68,9 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
 }


-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 template<>
 inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
@@ -82,29 +82,6 @@ inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
   return _mm256_castps_pd(src);
 }

-#if defined(CPU_CAPABILITY_AVX2)
-
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-#define DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) \
-template<> \
-inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
-  return _mm256_castp ## float_ch ## _si256(src); \
-} \
-template<> \
-inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
-  return _mm256_castsi256_p ## float_ch (src); \
-}
-
-DEFINE_FLOAT_INT_CAST(int64_t, double, d)
-DEFINE_FLOAT_INT_CAST(int32_t, double, d)
-DEFINE_FLOAT_INT_CAST(int16_t, double, d)
-DEFINE_FLOAT_INT_CAST(int64_t, float, s)
-DEFINE_FLOAT_INT_CAST(int32_t, float, s)
-DEFINE_FLOAT_INT_CAST(int16_t, float, s)
-
-#undef DEFINE_FLOAT_INT_CAST
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 template<int64_t scale = 1>
@@ -243,8 +220,6 @@ inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
     _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart
 }

-#endif // defined(CPU_CAPABILITY_AVX2)
+#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

-#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
-
 }}}
```
```diff
@@ -3,8 +3,8 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 #include <sleef.h>
 #endif
@@ -100,7 +100,7 @@ public:
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) {
-    __at_align32__ int16_t tmp_values[size()];
+    __at_align__ int16_t tmp_values[size()];
     std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
     return loadu(tmp_values);
   }
@@ -108,14 +108,14 @@ public:
     if (count == size()) {
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
     } else if (count > 0) {
-      __at_align32__ int16_t tmp_values[size()];
+      __at_align__ int16_t tmp_values[size()];
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
     }
   }
   template <int64_t mask>
   static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
-    __at_align32__ int16_t tmp_values[size()];
+    __at_align__ int16_t tmp_values[size()];
     a.store(tmp_values);
     if (mask & 0x01)
       tmp_values[0] = _mm256_extract_epi16(b.values, 0);
@@ -280,7 +280,7 @@ public:
   Vectorized<BFloat16> erfinv() const {
     __m256 lo, hi;
     cvtbf16_fp32(values, lo, hi);
-    __at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
     for (int64_t i = 0; i < size() / 2; i++) {
@@ -318,7 +318,7 @@ public:
   Vectorized<BFloat16> i0() const {
     __m256 lo, hi;
     cvtbf16_fp32(values, lo, hi);
-    __at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
     for (int64_t i = 0; i < size() / 2; i++) {
@@ -333,7 +333,7 @@ public:
     __m256 lo, hi;
     cvtbf16_fp32(values, lo, hi);
     constexpr auto sz = size();
-    __at_align32__ float tmp1[sz / 2], tmp2[sz / 2];
+    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
@@ -350,10 +350,10 @@ public:
     __m256 xlo, xhi;
     cvtbf16_fp32(values, lo, hi);
     cvtbf16_fp32(x.values, xlo, xhi);
-    __at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
-    __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2];
+    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
     for (int64_t i = 0; i < size() / 2; ++i) {
@@ -370,10 +370,10 @@ public:
     __m256 xlo, xhi;
     cvtbf16_fp32(values, lo, hi);
     cvtbf16_fp32(x.values, xlo, xhi);
-    __at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
-    __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2];
+    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
     _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
     _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
     for (int64_t i = 0; i < size() / 2; ++i) {
@@ -717,12 +717,13 @@ inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
   return cvtfp32_bf16(__m256(a), __m256(b));
 }

-#else //defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
   constexpr int64_t K = Vectorized<BFloat16>::size();
-  __at_align32__ float arr[K];
-  __at_align32__ BFloat16 arr2[K];
+  __at_align__ float arr[K];
+  __at_align__ BFloat16 arr2[K];
   a.store(arr2);
   convert(arr2, arr, K);
   return std::make_tuple(
@@ -732,15 +733,15 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {

 inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
   constexpr int64_t K = Vectorized<BFloat16>::size();
-  __at_align32__ float arr[K];
-  __at_align32__ BFloat16 arr2[K];
+  __at_align__ float arr[K];
+  __at_align__ BFloat16 arr2[K];
   a.store(arr);
   b.store(arr + Vectorized<float>::size());
   convert(arr, arr2, K);
   return Vectorized<BFloat16>::loadu(arr2);
 }

-#endif
+#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
@@ -759,7 +760,7 @@ void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
 }
 #else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
-  __at_align32__ float values[Vectorized<float>::size()];
+  __at_align__ float values[Vectorized<float>::size()];
   for (int k = 0; k < Vectorized<float>::size(); ++k) {
     values[k] = data[k];
   }
```
```diff
@@ -4,9 +4,10 @@
 // See Note [Do not compile initializers with AVX]

 #include <c10/util/complex.h>
-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 #include <sleef.h>
 #endif

@@ -15,7 +16,7 @@ namespace vec {
 // See Note [Acceptable use of anonymous namespace in header]
 namespace {

-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 template <> class Vectorized<c10::complex<double>> {
 private:
@@ -81,7 +82,7 @@ public:
     if (count == size())
       return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));

-    __at_align32__ double tmp_values[2*size()];
+    __at_align__ double tmp_values[2*size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -106,7 +107,7 @@ public:
   const c10::complex<double>& operator[](int idx) const = delete;
   c10::complex<double>& operator[](int idx) = delete;
   Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
-    __at_align32__ c10::complex<double> tmp[size()];
+    __at_align__ c10::complex<double> tmp[size()];
     store(tmp);
     for (int i = 0; i < size(); i++) {
       tmp[i] = f(tmp[i]);
@@ -288,8 +289,8 @@ public:
     return sqrt().reciprocal();
   }
   Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
-    __at_align32__ c10::complex<double> x_tmp[size()];
-    __at_align32__ c10::complex<double> y_tmp[size()];
+    __at_align__ c10::complex<double> x_tmp[size()];
+    __at_align__ c10::complex<double> y_tmp[size()];
     store(x_tmp);
     exp.store(y_tmp);
     for (int i = 0; i < size(); i++) {
```
```diff
@@ -4,9 +4,9 @@
 // See Note [Do not compile initializers with AVX]

 #include <c10/util/complex.h>
-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 #include <sleef.h>
 #endif

@@ -15,7 +15,7 @@ namespace vec {
 // See Note [Acceptable use of anonymous namespace in header]
 namespace {

-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 template <> class Vectorized<c10::complex<float>> {
 private:
@@ -117,7 +117,7 @@ public:
     if (count == size())
       return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));

-    __at_align32__ float tmp_values[2*size()];
+    __at_align__ float tmp_values[2*size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -142,7 +142,7 @@ public:
   const c10::complex<float>& operator[](int idx) const = delete;
   c10::complex<float>& operator[](int idx) = delete;
   Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
-    __at_align32__ c10::complex<float> tmp[size()];
+    __at_align__ c10::complex<float> tmp[size()];
     store(tmp);
     for (int i = 0; i < size(); i++) {
       tmp[i] = f(tmp[i]);
@@ -323,8 +323,8 @@ public:
     return sqrt().reciprocal();
   }
   Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const {
-    __at_align32__ c10::complex<float> x_tmp[size()];
-    __at_align32__ c10::complex<float> y_tmp[size()];
+    __at_align__ c10::complex<float> x_tmp[size()];
+    __at_align__ c10::complex<float> y_tmp[size()];
     store(x_tmp);
     exp.store(y_tmp);
     for (int i = 0; i < size(); i++) {
```
```diff
@@ -3,9 +3,9 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 #include <sleef.h>
 #endif

@@ -14,7 +14,8 @@ namespace vec {
 // See Note [Acceptable use of anonymous namespace in header]
 namespace {

-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 template <> class Vectorized<double> {
 private:
@@ -67,7 +68,7 @@ public:
     return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));


-    __at_align32__ double tmp_values[size()];
+    __at_align__ double tmp_values[size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -100,7 +101,7 @@ public:
     return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
   }
   Vectorized<double> map(double (*const f)(double)) const {
-    __at_align32__ double tmp[size()];
+    __at_align__ double tmp[size()];
     store(tmp);
     for (int64_t i = 0; i < size(); i++) {
       tmp[i] = f(tmp[i]);
@@ -175,8 +176,8 @@ public:
     return map(calc_i0e);
   }
   Vectorized<double> igamma(const Vectorized<double> &x) const {
-    __at_align32__ double tmp[size()];
-    __at_align32__ double tmp_x[size()];
+    __at_align__ double tmp[size()];
+    __at_align__ double tmp_x[size()];
     store(tmp);
     x.store(tmp_x);
     for (int64_t i = 0; i < size(); i++) {
@@ -185,8 +186,8 @@ public:
     return loadu(tmp);
   }
   Vectorized<double> igammac(const Vectorized<double> &x) const {
-    __at_align32__ double tmp[size()];
-    __at_align32__ double tmp_x[size()];
+    __at_align__ double tmp[size()];
+    __at_align__ double tmp_x[size()];
     store(tmp);
     x.store(tmp_x);
     for (int64_t i = 0; i < size(); i++) {
```
```diff
@@ -3,9 +3,9 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
 #include <sleef.h>
 #endif

@@ -14,7 +14,7 @@ namespace vec {
 // See Note [Acceptable use of anonymous namespace in header]
 namespace {

-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 template <> class Vectorized<float> {
 private:
@@ -76,7 +76,7 @@ public:
   static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
     if (count == size())
       return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
-    __at_align32__ float tmp_values[size()];
+    __at_align__ float tmp_values[size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -107,7 +107,7 @@ public:
     return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
   }
   Vectorized<float> map(float (*const f)(float)) const {
-    __at_align32__ float tmp[size()];
+    __at_align__ float tmp[size()];
     store(tmp);
     for (int64_t i = 0; i < size(); i++) {
       tmp[i] = f(tmp[i]);
@@ -213,8 +213,8 @@ public:
     return map(calc_i0e);
   }
   Vectorized<float> igamma(const Vectorized<float> &x) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_x[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
     store(tmp);
     x.store(tmp_x);
     for (int64_t i = 0; i < size(); i++) {
@@ -223,8 +223,8 @@ public:
     return loadu(tmp);
   }
   Vectorized<float> igammac(const Vectorized<float> &x) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_x[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
     store(tmp);
     x.store(tmp_x);
     for (int64_t i = 0; i < size(); i++) {
@@ -412,12 +412,11 @@ inline void convert(const float* src, float* dst, int64_t n) {
   }
 }

-#ifdef CPU_CAPABILITY_AVX2
 template <>
 Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
   return _mm256_fmadd_ps(a, b, c);
 }
-#endif

 #endif

```
```diff
@@ -3,8 +3,8 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 // Sleef offers vectorized versions of some transcedentals
 // such as sin, cos, tan etc..
 // However for now opting for STL, since we are not building
@@ -220,7 +220,7 @@ public:
       return res;
     }
     else {
-      __at_align32__ float tmp_values[size()];
+      __at_align__ float tmp_values[size()];
       for (auto i = 0; i < size(); ++i) {
        tmp_values[i] = 0.0;
       }
@@ -261,19 +261,19 @@ public:
   // Once we specialize that implementation for ARM
   // this should be removed. TODO (kimishpatel)
   float operator[](int idx) const {
-    __at_align32__ float tmp[size()];
+    __at_align__ float tmp[size()];
     store(tmp);
     return tmp[idx];
   }
   float operator[](int idx) {
-    __at_align32__ float tmp[size()];
+    __at_align__ float tmp[size()];
     store(tmp);
     return tmp[idx];
   }
   // For boolean version where we want to if any 1/all zero
   // etc. can be done faster in a different way.
   int zero_mask() const {
-    __at_align32__ float tmp[size()];
+    __at_align__ float tmp[size()];
     store(tmp);
     int mask = 0;
     for (int i = 0; i < size(); ++ i) {
@@ -284,8 +284,8 @@ public:
     return mask;
   }
   Vectorized<float> isnan() const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float res[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float res[size()];
     store(tmp);
     for (int i = 0; i < size(); i++) {
       if (_isnan(tmp[i])) {
@@ -297,7 +297,7 @@ public:
     return loadu(res);
   };
   Vectorized<float> map(float (*const f)(float)) const {
-    __at_align32__ float tmp[size()];
+    __at_align__ float tmp[size()];
     store(tmp);
     for (int64_t i = 0; i < size(); i++) {
       tmp[i] = f(tmp[i]);
@@ -332,8 +332,8 @@ public:
     return map(std::atan);
   }
   Vectorized<float> atan2(const Vectorized<float> &exp) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_exp[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_exp[size()];
     store(tmp);
     exp.store(tmp_exp);
     for (int64_t i = 0; i < size(); i++) {
@@ -342,8 +342,8 @@ public:
     return loadu(tmp);
   }
   Vectorized<float> copysign(const Vectorized<float> &sign) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_sign[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_sign[size()];
     store(tmp);
     sign.store(tmp_sign);
     for (size_type i = 0; i < size(); i++) {
@@ -367,8 +367,8 @@ public:
     return map(std::expm1);
   }
   Vectorized<float> fmod(const Vectorized<float>& q) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_q[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_q[size()];
     store(tmp);
     q.store(tmp_q);
     for (int64_t i = 0; i < size(); i++) {
@@ -377,8 +377,8 @@ public:
     return loadu(tmp);
   }
   Vectorized<float> hypot(const Vectorized<float> &b) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_b[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_b[size()];
     store(tmp);
     b.store(tmp_b);
     for (int64_t i = 0; i < size(); i++) {
@@ -393,8 +393,8 @@ public:
     return map(calc_i0e);
   }
   Vectorized<float> igamma(const Vectorized<float> &x) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_x[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
     store(tmp);
     x.store(tmp_x);
     for (int64_t i = 0; i < size(); i++) {
@@ -403,8 +403,8 @@ public:
     return loadu(tmp);
   }
   Vectorized<float> igammac(const Vectorized<float> &x) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_x[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
     store(tmp);
     x.store(tmp_x);
     for (int64_t i = 0; i < size(); i++) {
@@ -425,8 +425,8 @@ public:
     return map(std::log2);
   }
   Vectorized<float> nextafter(const Vectorized<float> &b) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_b[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_b[size()];
     store(tmp);
     b.store(tmp_b);
     for (int64_t i = 0; i < size(); i++) {
@@ -490,8 +490,8 @@ public:
     return this->sqrt().reciprocal();
   }
   Vectorized<float> pow(const Vectorized<float> &exp) const {
-    __at_align32__ float tmp[size()];
-    __at_align32__ float tmp_exp[size()];
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_exp[size()];
     store(tmp);
     exp.store(tmp_exp);
     for (int64_t i = 0; i < size(); i++) {
```
```diff
@@ -3,8 +3,8 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <c10/macros/Macros.h>

 namespace at {
@@ -55,7 +55,7 @@ public:
   }
   template <int64_t mask>
   static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) {
-    __at_align32__ int64_t tmp_values[size()];
+    __at_align__ int64_t tmp_values[size()];
     a.store(tmp_values);
     if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
@@ -93,7 +93,7 @@ public:
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vectorized<int64_t> loadu(const void* ptr, int64_t count) {
-    __at_align32__ int64_t tmp_values[size()];
+    __at_align__ int64_t tmp_values[size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -109,7 +109,7 @@ public:
       // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
     } else if (count > 0) {
-      __at_align32__ int64_t tmp_values[size()];
+      __at_align__ int64_t tmp_values[size()];
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
     }
@@ -216,7 +216,7 @@ public:
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vectorized<int32_t> loadu(const void* ptr, int32_t count) {
-    __at_align32__ int32_t tmp_values[size()];
+    __at_align__ int32_t tmp_values[size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -232,7 +232,7 @@ public:
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
-     __at_align32__ int32_t tmp_values[size()];
+     __at_align__ int32_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
    }
@@ -346,7 +346,7 @@ public:
   }
   template <int64_t mask>
   static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) {
-    __at_align32__ int16_t tmp_values[size()];
+    __at_align__ int16_t tmp_values[size()];
     a.store(tmp_values);
     if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
@@ -436,7 +436,7 @@ public:
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vectorized<int16_t> loadu(const void* ptr, int16_t count) {
-    __at_align32__ int16_t tmp_values[size()];
+    __at_align__ int16_t tmp_values[size()];
     // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
     // instructions while a loop would be compiled to one instruction.
@@ -452,7 +452,7 @@ public:
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
-     __at_align32__ int16_t tmp_values[size()];
```
|
__at_align__ int16_t tmp_values[size()];
|
||||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
|
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
|
||||||
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
|
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
|
||||||
}
|
}
|
||||||
|
|
@ -527,7 +527,7 @@ public:
|
||||||
}
|
}
|
||||||
template <int64_t mask>
|
template <int64_t mask>
|
||||||
static Vectorized<int8_t> blend(Vectorized<int8_t> a, Vectorized<int8_t> b) {
|
static Vectorized<int8_t> blend(Vectorized<int8_t> a, Vectorized<int8_t> b) {
|
||||||
__at_align32__ int8_t tmp_values[size()];
|
__at_align__ int8_t tmp_values[size()];
|
||||||
a.store(tmp_values);
|
a.store(tmp_values);
|
||||||
if (mask & 0x01)
|
if (mask & 0x01)
|
||||||
tmp_values[0] = _mm256_extract_epi8(b.values, 0);
|
tmp_values[0] = _mm256_extract_epi8(b.values, 0);
|
||||||
|
|
@ -685,7 +685,7 @@ public:
|
||||||
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||||
}
|
}
|
||||||
static Vectorized<int8_t> loadu(const void* ptr, int8_t count) {
|
static Vectorized<int8_t> loadu(const void* ptr, int8_t count) {
|
||||||
__at_align32__ int8_t tmp_values[size()];
|
__at_align__ int8_t tmp_values[size()];
|
||||||
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
|
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
|
||||||
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
|
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
|
||||||
// instructions while a loop would be compiled to one instruction.
|
// instructions while a loop would be compiled to one instruction.
|
||||||
|
|
@ -701,7 +701,7 @@ public:
|
||||||
// https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
|
// https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
|
||||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
|
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
|
||||||
} else if (count > 0) {
|
} else if (count > 0) {
|
||||||
__at_align32__ int8_t tmp_values[size()];
|
__at_align__ int8_t tmp_values[size()];
|
||||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
|
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
|
||||||
std::memcpy(ptr, tmp_values, count * sizeof(int8_t));
|
std::memcpy(ptr, tmp_values, count * sizeof(int8_t));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -3,8 +3,8 @@
 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
 // See Note [Do not compile initializers with AVX]

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/native/quantized/affine_quantizer_base.h>
 #include <c10/util/qint32.h>
 #include <c10/util/qint8.h>
@@ -39,7 +39,7 @@ namespace at {
 namespace vec {
 namespace {

-#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 struct Vectorizedqi {
 protected:
@@ -53,7 +53,6 @@ struct Vectorizedqi {
 }
 };

-#if defined(CPU_CAPABILITY_AVX2)
 template <typename T>
 __m256i pack_saturate_and_clamp(
 __m256i first,
@@ -94,7 +93,6 @@ __m256i pack_saturate_and_clamp<uint8_t>(
 _mm256_set1_epi8(min_val),
 _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val)));
 }
-#endif

 template <typename T>
 inline void __attribute__((always_inline)) QuantizeAvx2(
@@ -103,7 +101,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
 int len,
 float inverse_scale,
 int64_t zero_point) {
-#if defined(CPU_CAPABILITY_AVX2)
 constexpr int VLEN = 8;
 constexpr auto min_val = std::numeric_limits<typename T::underlying>::min();
 constexpr auto max_val = std::numeric_limits<typename T::underlying>::max();
@@ -212,10 +209,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
 std::min(std::max(transformed, float(min_val)), float(max_val));
 dst[i] = clipped;
 }
-#else
-at::native::quantize_vec<T>(
-1.0f / inverse_scale, zero_point, src, reinterpret_cast<T*>(dst), len);
-#endif
 }

 template<>
@@ -266,11 +259,7 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
 Vectorized<float> zero_point,
 Vectorized<float> scale_zp_premul) const {
 __m256 float_vals = _mm256_cvtepi32_ps(vals);
-#if defined(CPU_CAPABILITY_AVX2)
 return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)};
-#else
-return {scale * (Vectorized<float>(float_vals) - zero_point)};
-#endif
 }

 static Vectorized<c10::qint32> quantize(
@@ -286,39 +275,11 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
 }

 Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_max_epi32(vals, b.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<int32_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(int_vals.data()), vals);
-std::array<int32_t, size()> b_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(b_vals.data()), b.vals);
-std::array<int32_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::max<int32_t>(int_vals[i], b_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_min_epi32(vals, b.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<int32_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<int32_t, size()> b_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&b_vals), b.vals);
-std::array<int32_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::min<int32_t>(int_vals[i], b_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
@@ -328,65 +289,24 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
 Vectorized<c10::qint32> relu6(
 Vectorized<c10::qint32> zero_point,
 Vectorized<c10::qint32> q_six) {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_min_epi32(
 _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<int32_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<int32_t, size()> zero_point_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
-std::array<int32_t,size()> q_six_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
-std::array<int32_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::min<int32_t>(
-std::max<int32_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return {_mm256_sub_epi32(vals, b)};
-#else
-std::array<int32_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<int32_t, size()> b_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&b_vals), b.vals);
-std::array<int32_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = int_vals[i] - b_vals[i];
-}
-return {_mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals))};
-#endif
 }

 static Vectorized<c10::qint32> requantize_from_int(
 const int_vec_return_type& inp,
 float multiplier,
 int32_t zero_point) {
-#ifdef CPU_CAPABILITY_AVX2
 __m256 multiplier_v = _mm256_set1_ps(multiplier);
 __m256i zero_point_v = _mm256_set1_epi32(zero_point);

 __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v);
 __m256i rounded = _mm256_cvtps_epi32(scaled);
 return _mm256_add_epi32(rounded, zero_point_v);
-#else
-std::array<int32_t,size()> inp_vals;
-inp[0].store(inp_vals.data());
-std::array<int32_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] =
-nearbyint(static_cast<float>(inp_vals[i]) * multiplier) +
-zero_point;
-}
-return loadu(result_vals.data());
-#endif
 }

 void dump() const {
@@ -411,43 +331,16 @@ template <>
 Vectorized<c10::qint32> inline operator*(
 const Vectorized<c10::qint32>& a,
 const Vectorized<c10::qint32>& b) {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_mullo_epi32(a, b);
-#else
-// Pray the compiler can autovectorize this
-std::array<int32_t, std::decay_t<decltype(a)>::size()> a_vals;
-std::array<int32_t, std::decay_t<decltype(b)>::size()> b_vals;
-a.store(a_vals.data());
-b.store(b_vals.data());
-std::array<int32_t, std::decay_t<decltype(a)>::size()> result_vals;
-for (size_t i = 0; i < std::decay_t<decltype(a)>::size(); ++i) {
-result_vals[i] = a_vals[i] * b_vals[i];
-}
-return Vectorized<c10::qint32>::loadu(result_vals.data());
-#endif
 }

 template <>
 Vectorized<c10::qint32> inline operator+(
 const Vectorized<c10::qint32>& a,
 const Vectorized<c10::qint32>& b) {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_add_epi32(a, b);
-#else
-// Pray the compiler can autovectorize this
-std::array<int32_t, std::decay_t<decltype(a)>::size()> a_vals;
-std::array<int32_t, std::decay_t<decltype(b)>::size()> b_vals;
-a.store(a_vals.data());
-b.store(b_vals.data());
-std::array<int32_t, std::decay_t<decltype(a)>::size()> result_vals;
-for (size_t i = 0; i < std::decay_t<decltype(a)>::size(); ++i) {
-result_vals[i] = a_vals[i] + b_vals[i];
-}
-return Vectorized<c10::qint32>::loadu(result_vals.data());
-#endif
 }

-#ifdef CPU_CAPABILITY_AVX2
 /*
 * Convert values from int32 back to int8/uint8
 */
@@ -493,7 +386,6 @@ __m256i RequantizeAvx2(
 xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
 return xyzw_clamped_v;
 }
-#endif

 template<>
 struct Vectorized<c10::qint8> : public Vectorizedqi {
@@ -544,21 +436,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {

 private:
 __m256i cvtepi8_epi32(__m128i epi8_vals) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_cvtepi8_epi32(epi8_vals);
-#else // CPU_CAPABILITY_AVX2
-__m128i result_data[2];
-__m128i unpacked1 = _mm_unpacklo_epi8(epi8_vals, epi8_vals);
-__m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, unpacked1);
-__m128i shifted1 = _mm_srli_si128(epi8_vals, 4);
-__m128i shifted2 = _mm_srai_epi32(unpacked2, 24);
-result_data[0] = shifted2;
-__m128i unpacked3 = _mm_unpacklo_epi8(shifted1, shifted1);
-__m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, unpacked3);
-__m128i shifted3 = _mm_srai_epi32(unpacked4, 24);
-result_data[1] = shifted3;
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data));
-#endif
 }

 public:
@@ -576,7 +454,6 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
 __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));

-#if defined(CPU_CAPABILITY_AVX2)
 auto val0 =
 vec::fmadd(scale, Vectorized<float>(float_val0), scale_neg_zp_premul);
 auto val1 =
@@ -585,12 +462,6 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 vec::fmadd(scale, Vectorized<float>(float_val2), scale_neg_zp_premul);
 auto val3 =
 vec::fmadd(scale, Vectorized<float>(float_val3), scale_neg_zp_premul);
-#else
-auto val0 = scale * (Vectorized<float>(float_val0) - zero_point);
-auto val1 = scale * (Vectorized<float>(float_val1) - zero_point);
-auto val2 = scale * (Vectorized<float>(float_val2) - zero_point);
-auto val3 = scale * (Vectorized<float>(float_val3) - zero_point);
-#endif
 return {val0, val1, val2, val3};
 }

@@ -607,39 +478,11 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 }

 Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_max_epi8(vals, b.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<int8_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<int8_t, size()> b_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&b_vals), b.vals);
-std::array<int8_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::max<int8_t>(int_vals[i], b_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_min_epi8(vals, b.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<int8_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<int8_t, size()> b_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&b_vals), b.vals);
-std::array<int8_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::min<int8_t>(int_vals[i], b_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
@@ -649,29 +492,11 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 Vectorized<c10::qint8> relu6(
 Vectorized<c10::qint8> zero_point,
 Vectorized<c10::qint8> q_six) {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_min_epi8(
 _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<int8_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<int8_t, size()> zero_point_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
-std::array<int8_t, size()> q_six_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
-std::array<int8_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::min<int8_t>(
-std::max<int8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
 __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
 __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
@@ -701,55 +526,15 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 Vectorized<c10::qint32>(res_1),
 Vectorized<c10::qint32>(res_2),
 Vectorized<c10::qint32>(res_3)};
-#else
-// Pray the compiler can autovectorize this
-std::array<int8_t, size()> int_vals;
-store(int_vals.data());
-std::array<int8_t, size()> b_vals;
-b.store(b_vals.data());
-constexpr int elem_per_int_vec = size() / int_num_vecs();
-int32_t rv[int_num_vecs()][elem_per_int_vec];
-for (size_t i = 0; i < int_num_vecs(); ++i) {
-for (size_t j = 0; j < elem_per_int_vec; ++j) {
-rv[i][j] = static_cast<int32_t>(int_vals[i * elem_per_int_vec + j]) -
-static_cast<int32_t>(b_vals[i * elem_per_int_vec + j]);
-}
-}
-return {Vectorized<c10::qint32>::loadu(rv[0]),
-Vectorized<c10::qint32>::loadu(rv[1]),
-Vectorized<c10::qint32>::loadu(rv[2]),
-Vectorized<c10::qint32>::loadu(rv[3])};
-#endif
 }

 static Vectorized<c10::qint8> requantize_from_int(
 const int_vec_return_type& inp,
 float multiplier,
 int32_t zero_point) {
-#ifdef CPU_CAPABILITY_AVX2
 __m256 multiplier_v = _mm256_set1_ps(multiplier);
 __m256i zero_point_v = _mm256_set1_epi32(zero_point);
 return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
-#else
-// Pray the compiler can autovectorize this
-constexpr int elem_per_int_vec = size() / int_num_vecs();
-constexpr auto min_val = std::numeric_limits<value_type>::min();
-constexpr auto max_val = std::numeric_limits<value_type>::max();
-int32_t rv[int_num_vecs()][elem_per_int_vec];
-for (size_t i = 0; i < int_num_vecs(); ++i) {
-inp[i].store(rv[i]);
-}
-std::array<int8_t, size()> result_vals;
-for (size_t i = 0; i < int_num_vecs(); ++i) {
-for (size_t j = 0; j < elem_per_int_vec; ++j) {
-int32_t rounded =
-nearbyint(static_cast<float>(rv[i][j]) * multiplier) + zero_point;
-result_vals[i * elem_per_int_vec + j] =
-std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
-}
-}
-return loadu(result_vals.data());
-#endif
 }

 void dump() const {
@@ -817,20 +602,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {

 private:
 __m256i cvtepu8_epi32(__m128i epu8_vals) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_cvtepu8_epi32(epu8_vals);
-#else // CPU_CAPABILITY_AVX2
-__m128i result_data[2];
-__m128i zeros = _mm_setzero_si128();
-__m128i unpacked1 = _mm_unpacklo_epi8(epu8_vals, zeros);
-__m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, zeros);
-result_data[0] = unpacked2;
-__m128i shifted = _mm_srli_si128(epu8_vals, 4);
-__m128i unpacked3 = _mm_unpacklo_epi8(shifted, zeros);
-__m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, zeros);
-result_data[1] = unpacked4;
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data));
-#endif
 }

 public:
@@ -848,7 +620,6 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
 __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));

-#if defined(CPU_CAPABILITY_AVX2)
 auto val0 =
 vec::fmadd(scale, Vectorized<float>(float_val0), scale_zp_premul);
 auto val1 =
@@ -857,12 +628,6 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 vec::fmadd(scale, Vectorized<float>(float_val2), scale_zp_premul);
 auto val3 =
 vec::fmadd(scale, Vectorized<float>(float_val3), scale_zp_premul);
-#else
-auto val0 = scale * (Vectorized<float>(float_val0) - zero_point);
-auto val1 = scale * (Vectorized<float>(float_val1) - zero_point);
-auto val2 = scale * (Vectorized<float>(float_val2) - zero_point);
-auto val3 = scale * (Vectorized<float>(float_val3) - zero_point);
-#endif
 return {val0, val1, val2, val3};
 }

@@ -879,39 +644,11 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 }

 Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_max_epu8(vals, b.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<uint8_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<uint8_t, size()> b_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&b_vals), b.vals);
-std::array<uint8_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::max<uint8_t>(int_vals[i], b_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_min_epu8(vals, b.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<uint8_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<uint8_t, size()> b_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&b_vals), b.vals);
-std::array<uint8_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::min<uint8_t>(int_vals[i], b_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
@@ -921,29 +658,11 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 Vectorized<c10::quint8> relu6(
 Vectorized<c10::quint8> zero_point,
 Vectorized<c10::quint8> q_six) {
-#ifdef CPU_CAPABILITY_AVX2
 return _mm256_min_epu8(
 _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
-#else
-// Pray the compiler can autovectorize this
-std::array<uint8_t, size()> int_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
-std::array<uint8_t, size()> zero_point_vals;
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
-std::array<uint8_t, size()> q_six_vals;
-_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
-std::array<uint8_t, size()> result_vals;
-for (size_t i = 0; i < size(); ++i) {
-result_vals[i] = std::min<uint8_t>(
-std::max<uint8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
-}
-return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
-#endif
 }

 int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
-#ifdef CPU_CAPABILITY_AVX2
 __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
 __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
 __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
@@ -972,55 +691,15 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 Vectorized<c10::qint32>(res_1),
 Vectorized<c10::qint32>(res_2),
 Vectorized<c10::qint32>(res_3)};
-#else
-// Pray the compiler can autovectorize this
-std::array<uint8_t, size()> int_vals;
-std::array<uint8_t, size()> b_vals;
-store(int_vals.data());
-b.store(b_vals.data());
-static constexpr int elem_per_int_vec = size() / int_num_vecs();
-int32_t rv[int_num_vecs()][elem_per_int_vec];
-for (size_t i = 0; i < int_num_vecs(); ++i) {
-for (size_t j = 0; j < elem_per_int_vec; ++j) {
-rv[i][j] = static_cast<int32_t>(int_vals[i * elem_per_int_vec + j]) -
-static_cast<int32_t>(b_vals[i * elem_per_int_vec + j]);
-}
-}
-return {Vectorized<c10::qint32>::loadu(rv[0]),
-Vectorized<c10::qint32>::loadu(rv[1]),
-Vectorized<c10::qint32>::loadu(rv[2]),
-Vectorized<c10::qint32>::loadu(rv[3])};
-#endif
 }

 static Vectorized<c10::quint8> requantize_from_int(
 const int_vec_return_type& inp,
 float multiplier,
 int32_t zero_point) {
-#ifdef CPU_CAPABILITY_AVX2
 __m256 multiplier_v = _mm256_set1_ps(multiplier);
 __m256i zero_point_v = _mm256_set1_epi32(zero_point);
 return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
-#else
-// Pray the compiler can autovectorize this
-constexpr int elem_per_int_vec = size() / int_num_vecs();
-constexpr auto min_val = std::numeric_limits<value_type>::min();
-constexpr auto max_val = std::numeric_limits<value_type>::max();
-int32_t rv[int_num_vecs()][elem_per_int_vec];
-for (size_t i = 0; i < int_num_vecs(); ++i) {
-inp[i].store(rv[i]);
-}
-std::array<uint8_t, size()> result_vals;
-for (size_t i = 0; i < int_num_vecs(); ++i) {
-for (size_t j = 0; j < elem_per_int_vec; ++j) {
-int32_t rounded =
-nearbyint(static_cast<float>(rv[i][j]) * multiplier) + zero_point;
-result_vals[i * elem_per_int_vec + j] =
-std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
-}
-}
-return loadu(result_vals.data());
-#endif
 }

 void dump() const {
@@ -1497,6 +1176,5 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
 return a.maximum(b);
 }

-#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
+#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

 }}}

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h>
 #include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h>

@@ -1,6 +1,6 @@
 #pragma once
-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <c10/util/complex.h>

@@ -141,7 +141,7 @@ class Vectorized<ComplexDbl> {
 vec_vsx_ld(offset16, reinterpret_cast<const double*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {
@@ -153,7 +153,7 @@ class Vectorized<ComplexDbl> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(tmp_values));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(tmp_values));
 std::memcpy(
@@ -165,7 +165,7 @@ class Vectorized<ComplexDbl> {
 ComplexDbl& operator[](int idx) = delete;

 Vectorized<ComplexDbl> map(ComplexDbl (*const f)(ComplexDbl)) const {
-__at_align32__ ComplexDbl tmp[size()];
+__at_align__ ComplexDbl tmp[size()];
 store(tmp);
 for (int i = 0; i < size(); i++) {
 tmp[i] = f(tmp[i]);
@@ -174,7 +174,7 @@ class Vectorized<ComplexDbl> {
 }

 Vectorized<ComplexDbl> map(ComplexDbl (*const f)(const ComplexDbl&)) const {
-__at_align32__ ComplexDbl tmp[size()];
+__at_align__ ComplexDbl tmp[size()];
 store(tmp);
 for (int i = 0; i < size(); i++) {
 tmp[i] = f(tmp[i]);
@@ -455,8 +455,8 @@ class Vectorized<ComplexDbl> {
 }

 Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
-__at_align32__ ComplexDbl x_tmp[size()];
-__at_align32__ ComplexDbl y_tmp[size()];
+__at_align__ ComplexDbl x_tmp[size()];
+__at_align__ ComplexDbl y_tmp[size()];
 store(x_tmp);
 exp.store(y_tmp);
 for (int i = 0; i < size(); i++) {

@@ -1,7 +1,7 @@

 #pragma once
-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <c10/util/complex.h>

@@ -196,7 +196,7 @@ class Vectorized<ComplexFlt> {
 vec_vsx_ld(offset16, reinterpret_cast<const float*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {
@@ -209,7 +209,7 @@ class Vectorized<ComplexFlt> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(tmp_values));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(tmp_values));
 std::memcpy(
@@ -221,7 +221,7 @@ class Vectorized<ComplexFlt> {
 ComplexFlt& operator[](int idx) = delete;

 Vectorized<ComplexFlt> map(ComplexFlt (*const f)(ComplexFlt)) const {
-__at_align32__ ComplexFlt tmp[size()];
+__at_align__ ComplexFlt tmp[size()];
 store(tmp);
 for (int i = 0; i < size(); i++) {
 tmp[i] = f(tmp[i]);
@@ -230,7 +230,7 @@ class Vectorized<ComplexFlt> {
 }

 Vectorized<ComplexFlt> map(ComplexFlt (*const f)(const ComplexFlt&)) const {
-__at_align32__ ComplexFlt tmp[size()];
+__at_align__ ComplexFlt tmp[size()];
 store(tmp);
 for (int i = 0; i < size(); i++) {
 tmp[i] = f(tmp[i]);
@@ -434,8 +434,8 @@ class Vectorized<ComplexFlt> {
 }

 Vectorized<ComplexFlt> pow(const Vectorized<ComplexFlt>& exp) const {
-__at_align32__ ComplexFlt x_tmp[size()];
-__at_align32__ ComplexFlt y_tmp[size()];
+__at_align__ ComplexFlt x_tmp[size()];
+__at_align__ ComplexFlt y_tmp[size()];
 store(x_tmp);
 exp.store(y_tmp);
 for (int i = 0; i < size(); i++) {

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <sleef.h>

@@ -169,7 +169,7 @@ class Vectorized<double> {
 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@@ -179,7 +179,7 @@ class Vectorized<double> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, tmp_values);
 vec_vsx_st(_vec1, offset16, tmp_values);
 std::memcpy(

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <sleef.h>
 namespace at {
@@ -180,7 +180,7 @@ class Vectorized<float> {
 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@@ -190,7 +190,7 @@ class Vectorized<float> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, tmp_values);
 vec_vsx_st(_vec1, offset16, tmp_values);
 std::memcpy(

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 namespace at {
 namespace vec {
@@ -269,7 +269,7 @@ class Vectorized<int16_t> {
 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@@ -279,7 +279,7 @@ class Vectorized<int16_t> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, tmp_values);
 vec_vsx_st(_vec1, offset16, tmp_values);
 std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type));

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 namespace at {
 namespace vec {
@@ -199,7 +199,7 @@ class Vectorized<int32_t> {
 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@@ -209,7 +209,7 @@ class Vectorized<int32_t> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, tmp_values);
 vec_vsx_st(_vec1, offset16, tmp_values);
 std::memcpy(

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 namespace at {
 namespace vec {
@@ -148,7 +148,7 @@ class Vectorized<int64_t> {
 (vint64)vec_vsx_ld(offset16, dptr)};
 }

-__at_align32__ double tmp_values[size()];
+__at_align__ double tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {
@@ -161,7 +161,7 @@ class Vectorized<int64_t> {
 vec_vsx_st((vfloat64)_vec0, offset0, dptr);
 vec_vsx_st((vfloat64)_vec1, offset16, dptr);
 } else if (count > 0) {
-__at_align32__ double tmp_values[size()];
+__at_align__ double tmp_values[size()];
 vec_vsx_st((vfloat64)_vec0, offset0, tmp_values);
 vec_vsx_st((vfloat64)_vec1, offset16, tmp_values);
 std::memcpy(

@@ -1,7 +1,7 @@
 #pragma once

-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <c10/util/qint32.h>
 #include <array>
@@ -81,7 +81,7 @@ struct Vectorized<c10::qint32> {
 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
 }

-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@@ -91,7 +91,7 @@ struct Vectorized<c10::qint32> {
 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
 } else if (count > 0) {
-__at_align32__ value_type tmp_values[size()];
+__at_align__ value_type tmp_values[size()];
 vec_vsx_st(_vec0, offset0, tmp_values);
 vec_vsx_st(_vec1, offset16, tmp_values);
 std::memcpy(

@@ -1,7 +1,7 @@
 #pragma once
 
-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <c10/util/qint8.h>
 #include <array>

@@ -91,7 +91,7 @@ struct Vectorized<c10::qint8> {
       vec_vsx_ld(offset0, reinterpret_cast<const vint8*>(ptr)),
       vec_vsx_ld(offset16, reinterpret_cast<const vint8*>(ptr))};
   }
-  __at_align32__ value_type tmp_values[size()];
+  __at_align__ value_type tmp_values[size()];
   std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
   return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
 }

@@ -100,7 +100,7 @@ struct Vectorized<c10::qint8> {
     vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
     vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
   } else if (count > 0) {
-    __at_align32__ value_type tmp_values[size()];
+    __at_align__ value_type tmp_values[size()];
     vec_vsx_st(_vec0, offset0, tmp_values);
     vec_vsx_st(_vec1, offset16, tmp_values);
     std::memcpy(
@@ -1,7 +1,7 @@
 #pragma once
 
-#include <ATen/cpu/vec/vec256/intrinsics.h>
-#include <ATen/cpu/vec/vec256/vec256_base.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
 #include <c10/util/quint8.h>
 #include <array>

@@ -92,7 +92,7 @@ struct Vectorized<c10::quint8> {
       vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
       vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
   }
-  __at_align32__ value_type tmp_values[size()];
+  __at_align__ value_type tmp_values[size()];
   std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
   return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
 }

@@ -101,7 +101,7 @@ struct Vectorized<c10::quint8> {
     vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
     vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
   } else if (count > 0) {
-    __at_align32__ value_type tmp_values[size()];
+    __at_align__ value_type tmp_values[size()];
     vec_vsx_st(_vec0, offset0, tmp_values);
     vec_vsx_st(_vec1, offset16, tmp_values);
     std::memcpy(
@@ -1,7 +1,7 @@
 #pragma once
 #include <cstdint>
 #include <c10/macros/Macros.h>
-#include <ATen/cpu/vec/vec256/intrinsics.h>
+#include <ATen/cpu/vec/intrinsics.h>
 
 using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char;
 using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short;
aten/src/ATen/cpu/vec/vec512/vec512.h (new file, 195 lines)
@@ -0,0 +1,195 @@
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec512/vec512_float.h>
#include <ATen/cpu/vec/vec512/vec512_bfloat16.h>
#include <ATen/cpu/vec/vec512/vec512_double.h>
#include <ATen/cpu/vec/vec512/vec512_int.h>
#include <ATen/cpu/vec/vec512/vec512_qint.h>
#include <ATen/cpu/vec/vec512/vec512_complex_float.h>
#include <ATen/cpu/vec/vec512/vec512_complex_double.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>

namespace at {
namespace vec {

// See Note [Acceptable use of anonymous namespace in header]
namespace {

C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
  stream << val.val_;
  return stream;
}
C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
  stream << static_cast<int>(val.val_);
  return stream;
}
C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
  stream << static_cast<unsigned int>(val.val_);
  return stream;
}

template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
  T buf[Vectorized<T>::size()];
  vec.store(buf);
  stream << "vec[";
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    if (i != 0) {
      stream << ", ";
    }
    stream << buf[i];
  }
  stream << "]";
  return stream;
}

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  return _mm512_castpd_ps(src);
}

template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  return _mm512_castps_pd(src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
  return _mm512_i64gather_pd(vindex, base_addr, scale);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
  return _mm512_i32gather_ps(vindex, base_addr, scale);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
                   const Vectorized<int64_t>& vindex, const Vectorized<double>& mask) {
  auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF));
  auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ);
  return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
                   const Vectorized<int32_t>& vindex, const Vectorized<float>& mask) {
  auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF));
  auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
  return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  return _mm512_cvtpd_epi64(src);
}

template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  return _mm512_cvttps_epi32(src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0);
  __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4);
  return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
                        _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
  //
  // return:
  //   {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
  //   {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
  __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4,
                                  19, 3, 18, 2, 17, 1, 16, 0);
  __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12,
                                  27, 11, 26, 10, 25, 9, 24, 8);
  return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
                        _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
  // output:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  // The members of indices have been written in binary format for better understandability
  __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
  __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);

  return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
                        _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
  //   b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
  // output:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
  //          {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
  __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16,
                                  14, 12, 10, 8, 6, 4, 2, 0);
  __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17,
                                  15, 13, 11, 9, 7, 5, 3, 1);

  return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
                        _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
}

#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

}}}
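The interleave2/deinterleave2 specializations above are inverses of each other, which is how vectorized kernels move between interleaved and planar layouts. A minimal usage sketch, assuming a translation unit compiled with `CPU_CAPABILITY_AVX512` so the specializations above are visible:

#include <ATen/cpu/vec/vec512/vec512.h>
#include <utility>

using at::vec::Vectorized;

// Interleave two 16-float rows and then undo it; the result should equal the
// original (a, b) pair lane-for-lane.
std::pair<Vectorized<float>, Vectorized<float>> interleave_round_trip(
    const Vectorized<float>& a, const Vectorized<float>& b) {
  auto ab = at::vec::interleave2(a, b);             // {a0,b0,...}, {a8,b8,...}
  return at::vec::deinterleave2(ab.first, ab.second);
}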
aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h (new file, 879 lines)
@@ -0,0 +1,879 @@
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
|
||||||
|
o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
|
||||||
|
__m256i lo = _mm512_extracti32x8_epi32(a, 0);
|
||||||
|
__m256i hi = _mm512_extracti32x8_epi32(a, 1);
|
||||||
|
cvtbf16_fp32(lo, o1);
|
||||||
|
cvtbf16_fp32(hi, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
|
||||||
|
__m512i lo = _mm512_castps_si512(a);
|
||||||
|
__m512i hi = _mm512_castps_si512(b);
|
||||||
|
__m512i nan = _mm512_set1_epi32(0xffff);
|
||||||
|
auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
|
||||||
|
auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
|
||||||
|
__m512i ones = _mm512_set1_epi32(0x1);
|
||||||
|
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
|
||||||
|
// uint32_t lsb = (input >> 16) & 1;
|
||||||
|
auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
|
||||||
|
auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
|
||||||
|
// uint32_t rounding_bias = 0x7fff + lsb;
|
||||||
|
t_lo = _mm512_add_epi32(t_lo, vec_bias);
|
||||||
|
t_hi = _mm512_add_epi32(t_hi, vec_bias);
|
||||||
|
// input += rounding_bias;
|
||||||
|
t_lo = _mm512_add_epi32(t_lo, lo);
|
||||||
|
t_hi = _mm512_add_epi32(t_hi, hi);
|
||||||
|
// input = input >> 16;
|
||||||
|
t_lo = _mm512_srli_epi32(t_lo, 16);
|
||||||
|
t_hi = _mm512_srli_epi32(t_hi, 16);
|
||||||
|
// Check NaN before converting back to bf16
|
||||||
|
t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
|
||||||
|
t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);
|
||||||
|
|
||||||
|
t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
|
||||||
|
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
|
||||||
|
return _mm512_permutexvar_epi64(idx, t_lo);
|
||||||
|
}
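// Illustrative sketch (not part of the original file): a scalar restatement of
// the rounding that cvtfp32_bf16 performs per lane, following the step-by-step
// comments above (round-to-nearest-even on the upper 16 bits, NaN mapped to
// 0xffff).
#include <cstdint>
#include <cstring>
static inline uint16_t fp32_to_bf16_rne(float f) {
  uint32_t input;
  std::memcpy(&input, &f, sizeof(input));
  if (f != f) {                        // unordered, mirrors the _CMP_ORD_Q mask
    return 0xffff;
  }
  uint32_t lsb = (input >> 16) & 1;    // uint32_t lsb = (input >> 16) & 1;
  uint32_t rounding_bias = 0x7fff + lsb;
  input += rounding_bias;              // input += rounding_bias;
  return static_cast<uint16_t>(input >> 16);
}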
|
||||||
|
|
||||||
|
static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
|
||||||
|
__m512i lo = _mm512_castps_si512(a);
|
||||||
|
__m512i hi = _mm512_castps_si512(b);
|
||||||
|
lo = _mm512_srli_epi32(lo, 16);
|
||||||
|
hi = _mm512_srli_epi32(hi, 16);
|
||||||
|
auto out = _mm512_packus_epi32(lo, hi);
|
||||||
|
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
|
||||||
|
return _mm512_permutexvar_epi64(idx, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> class Vectorized<BFloat16> {
|
||||||
|
private:
|
||||||
|
__m512i values;
|
||||||
|
public:
|
||||||
|
using value_type = uint16_t;
|
||||||
|
using size_type = int;
|
||||||
|
static constexpr size_type size() {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
Vectorized() {}
|
||||||
|
Vectorized(__m512i v) : values(v) {}
|
||||||
|
Vectorized(BFloat16 val) {
|
||||||
|
value_type uw = val.x;
|
||||||
|
values = _mm512_set1_epi16(uw);
|
||||||
|
}
|
||||||
|
Vectorized(BFloat16 val1, BFloat16 val2, BFloat16 val3, BFloat16 val4,
|
||||||
|
BFloat16 val5, BFloat16 val6, BFloat16 val7, BFloat16 val8,
|
||||||
|
BFloat16 val9, BFloat16 val10, BFloat16 val11, BFloat16 val12,
|
||||||
|
BFloat16 val13, BFloat16 val14, BFloat16 val15, BFloat16 val16,
|
||||||
|
BFloat16 val17, BFloat16 val18, BFloat16 val19, BFloat16 val20,
|
||||||
|
BFloat16 val21, BFloat16 val22, BFloat16 val23, BFloat16 val24,
|
||||||
|
BFloat16 val25, BFloat16 val26, BFloat16 val27, BFloat16 val28,
|
||||||
|
BFloat16 val29, BFloat16 val30, BFloat16 val31, BFloat16 val32) {
|
||||||
|
values = _mm512_set_epi16(
|
||||||
|
val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
|
||||||
|
val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
|
||||||
|
val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
|
||||||
|
val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
|
||||||
|
}
|
||||||
|
operator __m512i() const {
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
BFloat16& operator[](int idx) = delete;
|
||||||
|
const BFloat16& operator[](int idx) const = delete;
|
||||||
|
int zero_mask() const {
|
||||||
|
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
|
||||||
|
return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
|
||||||
|
}
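  // Note (sketch, not in the original file): zero_mask() packs "lane == 0"
  // into one bit per lane, bit i set meaning element i is zero, so e.g.
  // __builtin_popcount(v.zero_mask()) counts the zero elements of v
  // (__builtin_popcount is a GCC/Clang builtin, used here only for
  // illustration).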
|
||||||
|
static Vectorized<BFloat16> loadu(const void* ptr) {
|
||||||
|
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
|
||||||
|
}
|
||||||
|
static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) {
|
||||||
|
__at_align__ int16_t tmp_values[size()];
|
||||||
|
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
|
||||||
|
return loadu(tmp_values);
|
||||||
|
}
|
||||||
|
void store(void* ptr, int count = size()) const {
|
||||||
|
if (count == size()) {
|
||||||
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
|
||||||
|
} else if (count > 0) {
|
||||||
|
__at_align__ int16_t tmp_values[size()];
|
||||||
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values);
|
||||||
|
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <int64_t mask>
|
||||||
|
static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
__at_align__ int16_t tmp_values[size()];
|
||||||
|
a.store(tmp_values);
|
||||||
|
if (mask & 0x01)
|
||||||
|
tmp_values[0] = b.values[31];
|
||||||
|
if (mask & 0x02)
|
||||||
|
tmp_values[1] = b.values[30];
|
||||||
|
if (mask & 0x04)
|
||||||
|
tmp_values[2] = b.values[29];
|
||||||
|
if (mask & 0x08)
|
||||||
|
tmp_values[3] = b.values[28];
|
||||||
|
if (mask & 0x10)
|
||||||
|
tmp_values[4] = b.values[27];
|
||||||
|
if (mask & 0x20)
|
||||||
|
tmp_values[5] = b.values[26];
|
||||||
|
if (mask & 0x40)
|
||||||
|
tmp_values[6] = b.values[25];
|
||||||
|
if (mask & 0x80)
|
||||||
|
tmp_values[7] = b.values[24];
|
||||||
|
if (mask & 0x100)
|
||||||
|
tmp_values[8] = b.values[23];
|
||||||
|
if (mask & 0x200)
|
||||||
|
tmp_values[9] = b.values[22];
|
||||||
|
if (mask & 0x400)
|
||||||
|
tmp_values[10] = b.values[21];
|
||||||
|
if (mask & 0x800)
|
||||||
|
tmp_values[11] = b.values[20];
|
||||||
|
if (mask & 0x1000)
|
||||||
|
tmp_values[12] = b.values[19];
|
||||||
|
if (mask & 0x2000)
|
||||||
|
tmp_values[13] = b.values[18];
|
||||||
|
if (mask & 0x4000)
|
||||||
|
tmp_values[14] = b.values[17];
|
||||||
|
if (mask & 0x8000)
|
||||||
|
tmp_values[15] = b.values[16];
|
||||||
|
if (mask & 0x10000)
|
||||||
|
tmp_values[16] = b.values[15];
|
||||||
|
if (mask & 0x20000)
|
||||||
|
tmp_values[17] = b.values[14];
|
||||||
|
if (mask & 0x40000)
|
||||||
|
tmp_values[18] = b.values[13];
|
||||||
|
if (mask & 0x80000)
|
||||||
|
tmp_values[19] = b.values[12];
|
||||||
|
if (mask & 0x100000)
|
||||||
|
tmp_values[20] = b.values[11];
|
||||||
|
if (mask & 0x200000)
|
||||||
|
tmp_values[21] = b.values[10];
|
||||||
|
if (mask & 0x400000)
|
||||||
|
tmp_values[22] = b.values[9];
|
||||||
|
if (mask & 0x800000)
|
||||||
|
tmp_values[23] = b.values[8];
|
||||||
|
if (mask & 0x1000000)
|
||||||
|
tmp_values[24] = b.values[7];
|
||||||
|
if (mask & 0x2000000)
|
||||||
|
tmp_values[25] = b.values[6];
|
||||||
|
if (mask & 0x4000000)
|
||||||
|
tmp_values[26] = b.values[5];
|
||||||
|
if (mask & 0x8000000)
|
||||||
|
tmp_values[27] = b.values[4];
|
||||||
|
if (mask & 0x10000000)
|
||||||
|
tmp_values[28] = b.values[3];
|
||||||
|
if (mask & 0x20000000)
|
||||||
|
tmp_values[29] = b.values[2];
|
||||||
|
if (mask & 0x40000000)
|
||||||
|
tmp_values[30] = b.values[1];
|
||||||
|
if (mask & 0x80000000)
|
||||||
|
tmp_values[31] = b.values[0];
|
||||||
|
return loadu(tmp_values);
|
||||||
|
}
|
||||||
|
static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a,
|
||||||
|
const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask) {
|
||||||
|
auto all_ones = _mm512_set1_epi16(0xFFFF);
|
||||||
|
auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
|
||||||
|
return _mm512_mask_blend_epi16(mask_, a.values, b.values);
|
||||||
|
}
|
||||||
|
template<typename step_t>
|
||||||
|
static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step = static_cast<step_t>(1)) {
|
||||||
|
return Vectorized<BFloat16>(
|
||||||
|
base, base + step, base + 2 * step, base + 3 * step,
|
||||||
|
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
|
||||||
|
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
|
||||||
|
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
|
||||||
|
base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
|
||||||
|
base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
|
||||||
|
base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
|
||||||
|
base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
|
||||||
|
}
|
||||||
|
static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a,
|
||||||
|
const Vectorized<BFloat16>& b, int64_t count = size()) {
|
||||||
|
switch (count) {
|
||||||
|
case 0:
|
||||||
|
return a;
|
||||||
|
case 1:
|
||||||
|
return blend<1>(a, b);
|
||||||
|
case 2:
|
||||||
|
return blend<3>(a, b);
|
||||||
|
case 3:
|
||||||
|
return blend<7>(a, b);
|
||||||
|
case 4:
|
||||||
|
return blend<15>(a, b);
|
||||||
|
case 5:
|
||||||
|
return blend<31>(a, b);
|
||||||
|
case 6:
|
||||||
|
return blend<63>(a, b);
|
||||||
|
case 7:
|
||||||
|
return blend<127>(a, b);
|
||||||
|
case 8:
|
||||||
|
return blend<255>(a, b);
|
||||||
|
case 9:
|
||||||
|
return blend<511>(a, b);
|
||||||
|
case 10:
|
||||||
|
return blend<1023>(a, b);
|
||||||
|
case 11:
|
||||||
|
return blend<2047>(a, b);
|
||||||
|
case 12:
|
||||||
|
return blend<4095>(a, b);
|
||||||
|
case 13:
|
||||||
|
return blend<8191>(a, b);
|
||||||
|
case 14:
|
||||||
|
return blend<16383>(a, b);
|
||||||
|
case 15:
|
||||||
|
return blend<32767>(a, b);
|
||||||
|
case 16:
|
||||||
|
return blend<65535>(a, b);
|
||||||
|
case 17:
|
||||||
|
return blend<131071>(a, b);
|
||||||
|
case 18:
|
||||||
|
return blend<262143>(a, b);
|
||||||
|
case 19:
|
||||||
|
return blend<524287>(a, b);
|
||||||
|
case 20:
|
||||||
|
return blend<1048575>(a, b);
|
||||||
|
case 21:
|
||||||
|
return blend<2097151>(a, b);
|
||||||
|
case 22:
|
||||||
|
return blend<4194303>(a, b);
|
||||||
|
case 23:
|
||||||
|
return blend<8388607>(a, b);
|
||||||
|
case 24:
|
||||||
|
return blend<16777215>(a, b);
|
||||||
|
case 25:
|
||||||
|
return blend<33554431>(a, b);
|
||||||
|
case 26:
|
||||||
|
return blend<67108863>(a, b);
|
||||||
|
case 27:
|
||||||
|
return blend<134217727>(a, b);
|
||||||
|
case 28:
|
||||||
|
return blend<268435455>(a, b);
|
||||||
|
case 29:
|
||||||
|
return blend<536870911>(a, b);
|
||||||
|
case 30:
|
||||||
|
return blend<1073741823>(a, b);
|
||||||
|
case 31:
|
||||||
|
return blend<2147483647>(a, b);
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> map(const __m512 (*const vop)(__m512)) const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
const auto o1 = vop(lo);
|
||||||
|
const auto o2 = vop(hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> abs() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
const auto mask = _mm512_set1_ps(-0.f);
|
||||||
|
const auto o1 = _mm512_andnot_ps(mask, lo);
|
||||||
|
const auto o2 = _mm512_andnot_ps(mask, hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> angle() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto angle_lambda = [](__m512 values) {
|
||||||
|
const auto zero_vec = _mm512_set1_ps(0.f);
|
||||||
|
const auto nan_vec = _mm512_set1_ps(NAN);
|
||||||
|
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
|
||||||
|
const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
|
||||||
|
not_nan_mask, 0xFFFFFFFF);
|
||||||
|
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec),
|
||||||
|
zero_vec, _CMP_EQ_OQ);
|
||||||
|
const auto pi = _mm512_set1_ps(c10::pi<float>);
|
||||||
|
|
||||||
|
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
|
||||||
|
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
|
||||||
|
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
|
||||||
|
return angle;
|
||||||
|
};
|
||||||
|
auto o1 = angle_lambda(lo);
|
||||||
|
auto o2 = angle_lambda(hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> real() const {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> imag() const {
|
||||||
|
return _mm512_set1_epi16(0);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> conj() const {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> acos() const {
|
||||||
|
return map(Sleef_acosf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> asin() const {
|
||||||
|
return map(Sleef_asinf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> atan() const {
|
||||||
|
return map(Sleef_atanf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> atan2(const Vectorized<BFloat16> &b) const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
__m512 b1, b2;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
cvtbf16_fp32(b.values, b1, b2);
|
||||||
|
auto o1 = Sleef_atan2f16_u10(lo, b1);
|
||||||
|
auto o2 = Sleef_atan2f16_u10(hi, b2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> copysign(const Vectorized<BFloat16> &sign) const {
|
||||||
|
// copy sign bit (0x8000) from sign and remaining bits from values
|
||||||
|
__m512i mask_value = _mm512_set1_epi32(~0x80008000);
|
||||||
|
__m512i mask_signbit = _mm512_set1_epi32(0x80008000);
|
||||||
|
return Vectorized<BFloat16>(
|
||||||
|
_mm512_or_si512(
|
||||||
|
_mm512_and_si512(values, mask_value),
|
||||||
|
_mm512_and_si512(sign, mask_signbit)));
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> erf() const {
|
||||||
|
return map(Sleef_erff16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> erfc() const {
|
||||||
|
return map(Sleef_erfcf16_u15);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> erfinv() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
|
||||||
|
for (int64_t i = 0; i < size() / 2; i++) {
|
||||||
|
tmp1[i] = calc_erfinv(tmp1[i]);
|
||||||
|
tmp2[i] = calc_erfinv(tmp2[i]);
|
||||||
|
}
|
||||||
|
auto o1 = _mm512_loadu_ps(tmp1);
|
||||||
|
auto o2 = _mm512_loadu_ps(tmp2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> exp() const {
|
||||||
|
return map(Sleef_expf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> expm1() const {
|
||||||
|
return map(Sleef_expm1f16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> fmod(const Vectorized<BFloat16> & q) const {
|
||||||
|
__m512 x_lo, x_hi;
|
||||||
|
cvtbf16_fp32(values, x_lo, x_hi);
|
||||||
|
__m512 q_lo, q_hi;
|
||||||
|
cvtbf16_fp32(q.values, q_lo, q_hi);
|
||||||
|
auto o1 = Sleef_fmodf16(x_lo, q_lo);
|
||||||
|
auto o2 = Sleef_fmodf16(x_hi, q_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> hypot(const Vectorized<BFloat16> &b) const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
__m512 b1, b2;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
cvtbf16_fp32(b.values, b1, b2);
|
||||||
|
auto o1 = Sleef_hypotf16_u05(lo, b1);
|
||||||
|
auto o2 = Sleef_hypotf16_u05(hi, b2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> i0() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
|
||||||
|
for (int64_t i = 0; i < size() / 2; i++) {
|
||||||
|
tmp1[i] = calc_i0(tmp1[i]);
|
||||||
|
tmp2[i] = calc_i0(tmp2[i]);
|
||||||
|
}
|
||||||
|
auto o1 = _mm512_loadu_ps(tmp1);
|
||||||
|
auto o2 = _mm512_loadu_ps(tmp2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> i0e() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
constexpr auto sz = size();
|
||||||
|
__at_align__ float tmp1[sz / 2], tmp2[sz / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
|
||||||
|
|
||||||
|
for (auto i = decltype(sz){0}; i < sz / 2; i++) {
|
||||||
|
tmp1[i] = calc_i0e(tmp1[i]);
|
||||||
|
tmp2[i] = calc_i0e(tmp2[i]);
|
||||||
|
}
|
||||||
|
const auto o1 = _mm512_loadu_ps(tmp1);
|
||||||
|
const auto o2 = _mm512_loadu_ps(tmp2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> igamma(const Vectorized<BFloat16> &x) const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
__m512 xlo, xhi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
cvtbf16_fp32(x.values, xlo, xhi);
|
||||||
|
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
|
||||||
|
__at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
|
||||||
|
for (int64_t i = 0; i < size() / 2; ++i) {
|
||||||
|
tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
|
||||||
|
tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
|
||||||
|
}
|
||||||
|
auto o1 = _mm512_loadu_ps(tmp1);
|
||||||
|
auto o2 = _mm512_loadu_ps(tmp2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> igammac(const Vectorized<BFloat16> &x) const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
__m512 xlo, xhi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
cvtbf16_fp32(x.values, xlo, xhi);
|
||||||
|
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
|
||||||
|
__at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
|
||||||
|
for (int64_t i = 0; i < size() / 2; ++i) {
|
||||||
|
tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
|
||||||
|
tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
|
||||||
|
}
|
||||||
|
auto o1 = _mm512_loadu_ps(tmp1);
|
||||||
|
auto o2 = _mm512_loadu_ps(tmp2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> log() const {
|
||||||
|
return map(Sleef_logf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> log2() const {
|
||||||
|
return map(Sleef_log2f16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> log10() const {
|
||||||
|
return map(Sleef_log10f16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> log1p() const {
|
||||||
|
return map(Sleef_log1pf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> frac() const;
|
||||||
|
Vectorized<BFloat16> sin() const {
|
||||||
|
return map(Sleef_sinf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> sinh() const {
|
||||||
|
return map(Sleef_sinhf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> cos() const {
|
||||||
|
return map(Sleef_cosf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> cosh() const {
|
||||||
|
return map(Sleef_coshf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> ceil() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto o1 = _mm512_ceil_ps(lo);
|
||||||
|
auto o2 = _mm512_ceil_ps(hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> floor() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto o1 = _mm512_floor_ps(lo);
|
||||||
|
auto o2 = _mm512_floor_ps(hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> neg() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto mask = _mm512_set1_ps(-0.f);
|
||||||
|
auto o1 = _mm512_xor_ps(mask, lo);
|
||||||
|
auto o2 = _mm512_xor_ps(mask, hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> round() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||||
|
auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> tan() const {
|
||||||
|
return map(Sleef_tanf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> tanh() const {
|
||||||
|
return map(Sleef_tanhf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> trunc() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
|
||||||
|
auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> lgamma() const {
|
||||||
|
return map(Sleef_lgammaf16_u10);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> sqrt() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto o1 = _mm512_sqrt_ps(lo);
|
||||||
|
auto o2 = _mm512_sqrt_ps(hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> reciprocal() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto ones = _mm512_set1_ps(1);
|
||||||
|
auto o1 = _mm512_div_ps(ones, lo);
|
||||||
|
auto o2 = _mm512_div_ps(ones, hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> rsqrt() const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
auto ones = _mm512_set1_ps(1);
|
||||||
|
auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
|
||||||
|
auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> pow(const Vectorized<BFloat16> &b) const {
|
||||||
|
__m512 lo, hi;
|
||||||
|
__m512 b1, b2;
|
||||||
|
cvtbf16_fp32(values, lo, hi);
|
||||||
|
cvtbf16_fp32(b.values, b1, b2);
|
||||||
|
auto o1 = Sleef_powf16_u10(lo, b1);
|
||||||
|
auto o2 = Sleef_powf16_u10(hi, b2);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> inline operator>(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> inline operator<(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> inline operator>=(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> inline operator<=(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> inline operator==(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> inline operator!=(const Vectorized<BFloat16>& other) const;
|
||||||
|
|
||||||
|
Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
|
||||||
|
Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
|
||||||
|
};
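// Illustrative sketch (not part of the original file): the `count` overloads
// of loadu() and store() above let a kernel handle a tail shorter than size()
// without reading or writing past the end of the buffer. `vec_op` is a
// placeholder for any Vectorized<BFloat16> -> Vectorized<BFloat16> callable.
template <typename F>
inline void map_bf16(BFloat16* out, const BFloat16* in, int64_t n, F vec_op) {
  constexpr int64_t W = Vectorized<BFloat16>::size();   // 32 lanes per vector
  int64_t i = 0;
  for (; i + W <= n; i += W) {                           // full vectors
    vec_op(Vectorized<BFloat16>::loadu(in + i)).store(out + i);
  }
  if (i < n) {                                           // short tail
    auto v = Vectorized<BFloat16>::loadu(in + i, static_cast<int16_t>(n - i));
    vec_op(v).store(out + i, static_cast<int>(n - i));
  }
}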
|
||||||
|
|
||||||
|
template<typename Op>
|
||||||
|
Vectorized<BFloat16> static inline bfloat16_binary_op_as_fp32(const Vectorized<BFloat16>& a,
|
||||||
|
const Vectorized<BFloat16>& b, Op op) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 b_lo, b_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
|
||||||
|
auto o1 = op(a_lo, b_lo);
|
||||||
|
auto o2 = op(a_hi, b_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Op>
|
||||||
|
Vectorized<BFloat16> static inline bfloat16_compare_as_fp32(const Vectorized<BFloat16>& a,
|
||||||
|
const Vectorized<BFloat16>& b, Op op) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 b_lo, b_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
|
||||||
|
auto o1 = op(a_lo, b_lo);
|
||||||
|
auto o2 = op(a_hi, b_hi);
|
||||||
|
return merge_compare_result(o1, o2);
|
||||||
|
}
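// Scalar model of the lane-wise behaviour of the helpers above (illustrative
// only): both inputs are widened to float, the operation runs in fp32, and the
// result is rounded back to bfloat16 (cvtfp32_bf16 for arithmetic,
// merge_compare_result for comparisons). Exact rounding of the
// c10::BFloat16(float) constructor is assumed to match.
static inline BFloat16 bf16_add_scalar(BFloat16 a, BFloat16 b) {
  return BFloat16(static_cast<float>(a) + static_cast<float>(b));
}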
|
||||||
|
|
||||||
|
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const {
|
||||||
|
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
|
||||||
|
auto zero_vec = _mm512_set1_epi32(0);
|
||||||
|
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const {
|
||||||
|
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
|
||||||
|
auto zero_vec = _mm512_set1_epi32(0);
|
||||||
|
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const {
|
||||||
|
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
|
||||||
|
auto zero_vec = _mm512_set1_epi32(0);
|
||||||
|
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const {
|
||||||
|
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
|
||||||
|
auto zero_vec = _mm512_set1_epi32(0);
|
||||||
|
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const {
|
||||||
|
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
|
||||||
|
auto zero_vec = _mm512_set1_epi32(0);
|
||||||
|
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
|
||||||
|
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
|
||||||
|
auto zero_vec = _mm512_set1_epi32(0);
|
||||||
|
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_OQ);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return _mm512_and_si512(a, b);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return _mm512_or_si512(a, b);
|
||||||
|
}
|
||||||
|
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
return _mm512_xor_si512(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
|
||||||
|
return (*this == other) & Vectorized<BFloat16>(1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
|
||||||
|
return (*this != other) & Vectorized<BFloat16>(1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
|
||||||
|
return (*this > other) & Vectorized<BFloat16>(1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
|
||||||
|
return (*this >= other) & Vectorized<BFloat16>(1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
|
||||||
|
return (*this < other) & Vectorized<BFloat16>(1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
|
||||||
|
return (*this <= other) & Vectorized<BFloat16>(1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// frac. Implement this here so we can use subtraction
|
||||||
|
Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
|
||||||
|
return *this - this->trunc();
|
||||||
|
}
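// frac() keeps the sign of its input because it truncates toward zero:
// frac(2.75f) == 0.75f and frac(-2.75f) == -0.75f. Scalar sketch of the same
// definition (illustrative only; assumes <cmath> is available):
static inline float frac_scalar(float x) {
  return x - std::trunc(x);
}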
|
||||||
|
|
||||||
|
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
||||||
|
// either input is a NaN.
|
||||||
|
template <>
|
||||||
|
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 b_lo, b_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
|
||||||
|
auto max_lo = _mm512_max_ps(a_lo, b_lo);
|
||||||
|
auto max_hi = _mm512_max_ps(a_hi, b_hi);
|
||||||
|
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
|
||||||
|
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
|
||||||
|
auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask));
|
||||||
|
auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask));
|
||||||
|
// Exploit the fact that all-ones is a NaN.
|
||||||
|
auto o1 = _mm512_or_ps(max_lo, nan_lo);
|
||||||
|
auto o2 = _mm512_or_ps(max_hi, nan_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
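// Lane-wise model of maximum() above (illustrative only): unlike std::max, a
// NaN in either input yields NaN. The vector code gets this by OR-ing an
// all-ones mask into the result, and an all-ones float bit pattern is a NaN.
static inline float maximum_scalar(float a, float b) {
  if (a != a || b != b) {                            // NaN check
    return std::numeric_limits<float>::quiet_NaN();  // assumes <limits>
  }
  return a > b ? a : b;
}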
|
||||||
|
|
||||||
|
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
||||||
|
// either input is a NaN.
|
||||||
|
template <>
|
||||||
|
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 b_lo, b_hi;
|
||||||
|
__m512i zero_vec = _mm512_set1_epi32(0);
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
|
||||||
|
auto min_lo = _mm512_min_ps(a_lo, b_lo);
|
||||||
|
auto min_hi = _mm512_min_ps(a_hi, b_hi);
|
||||||
|
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
|
||||||
|
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
|
||||||
|
auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
|
||||||
|
0xFFFFFFFF));
|
||||||
|
auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
|
||||||
|
0xFFFFFFFF));
|
||||||
|
// Exploit the fact that all-ones is a NaN.
|
||||||
|
auto o1 = _mm512_or_ps(min_lo, nan_lo);
|
||||||
|
auto o2 = _mm512_or_ps(min_hi, nan_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
|
||||||
|
const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 min_lo, min_hi;
|
||||||
|
__m512 max_lo, max_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(min), min_lo, min_hi);
|
||||||
|
cvtbf16_fp32(__m512i(max), max_lo, max_hi);
|
||||||
|
auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
|
||||||
|
auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 max_lo, max_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(max), max_lo, max_hi);
|
||||||
|
auto o1 = _mm512_min_ps(max_lo, a_lo);
|
||||||
|
auto o2 = _mm512_min_ps(max_hi, a_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 min_lo, min_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(min), min_lo, min_hi);
|
||||||
|
auto o1 = _mm512_max_ps(min_lo, a_lo);
|
||||||
|
auto o2 = _mm512_max_ps(min_hi, a_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
|
||||||
|
int64_t i;
|
||||||
|
#pragma unroll
|
||||||
|
for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
|
||||||
|
auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
|
||||||
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (; i < n; i++) {
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
|
||||||
|
const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
|
||||||
|
__m512 a_lo, a_hi;
|
||||||
|
__m512 b_lo, b_hi;
|
||||||
|
__m512 c_lo, c_hi;
|
||||||
|
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
|
||||||
|
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
|
||||||
|
cvtbf16_fp32(__m512i(c), c_lo, c_hi);
|
||||||
|
auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
|
||||||
|
auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
|
||||||
|
return cvtfp32_bf16(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
|
||||||
|
__m512 o1, o2;
|
||||||
|
cvtbf16_fp32(__m512i(a), o1, o2);
|
||||||
|
return std::make_tuple(o1, o2);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||||
|
return cvtfp32_bf16(__m512(a), __m512(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
|
||||||
|
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
|
||||||
|
constexpr int64_t K = Vectorized<BFloat16>::size();
|
||||||
|
__at_align__ float arr[K];
|
||||||
|
__at_align__ BFloat16 arr2[K];
|
||||||
|
a.store(arr2);
|
||||||
|
convert(arr2, arr, K);
|
||||||
|
return std::make_tuple(
|
||||||
|
Vectorized<float>::loadu(arr),
|
||||||
|
Vectorized<float>::loadu(arr + Vectorized<float>::size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||||
|
constexpr int64_t K = Vectorized<BFloat16>::size();
|
||||||
|
__at_align__ float arr[K];
|
||||||
|
__at_align__ BFloat16 arr2[K];
|
||||||
|
a.store(arr);
|
||||||
|
b.store(arr + Vectorized<float>::size());
|
||||||
|
convert(arr, arr2, K);
|
||||||
|
return Vectorized<BFloat16>::loadu(arr2);
|
||||||
|
}
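// Typical use of the two helpers above inside a kernel (illustrative sketch,
// not part of the original file): widen one bf16 vector to two float vectors,
// run the math in fp32, and narrow back. `fp32_op` is a placeholder for any
// Vectorized<float> -> Vectorized<float> callable.
template <typename F>
inline Vectorized<BFloat16> bf16_unary_via_fp32(const Vectorized<BFloat16>& x,
                                                F fp32_op) {
  Vectorized<float> lo, hi;
  std::tie(lo, hi) = convert_bfloat16_float(x);  // assumes <tuple> is visible
  return convert_float_bfloat16(fp32_op(lo), fp32_op(hi));
}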
|
||||||
|
|
||||||
|
#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
|
||||||
|
auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data));
|
||||||
|
__m512 out_values;
|
||||||
|
cvtbf16_fp32(values, out_values);
|
||||||
|
out = out_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
|
||||||
|
auto vec = Vectorized<c10::BFloat16>::loadu(data);
|
||||||
|
__m512 out1_values, out2_values;
|
||||||
|
cvtbf16_fp32(vec, out1_values, out2_values);
|
||||||
|
out1 = out1_values;
|
||||||
|
out2 = out2_values;
|
||||||
|
}
|
||||||
|
#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
|
||||||
|
__at_align__ float values[Vectorized<float>::size()];
|
||||||
|
for (int k = 0; k < Vectorized<float>::size(); ++k) {
|
||||||
|
values[k] = data[k];
|
||||||
|
}
|
||||||
|
out = Vectorized<float>::loadu(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
|
||||||
|
load_fp32_from_bf16(data, out1);
|
||||||
|
data += Vectorized<float>::size();
|
||||||
|
load_fp32_from_bf16(data, out2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}}}
|
||||||
aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h (new file, 526 lines)
@@ -0,0 +1,526 @@
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <c10/util/complex.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

template <> class Vectorized<c10::complex<double>> {
|
||||||
|
private:
|
||||||
|
__m512d values;
|
||||||
|
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
|
||||||
|
public:
|
||||||
|
using value_type = c10::complex<double>;
|
||||||
|
using size_type = int;
|
||||||
|
static constexpr size_type size() {
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
Vectorized() {}
|
||||||
|
Vectorized(__m512d v) : values(v) {}
|
||||||
|
Vectorized(c10::complex<double> val) {
|
||||||
|
double real_value = val.real();
|
||||||
|
double imag_value = val.imag();
|
||||||
|
values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value,
|
||||||
|
real_value, imag_value, real_value, imag_value);
|
||||||
|
}
|
||||||
|
Vectorized(c10::complex<double> val1, c10::complex<double> val2,
|
||||||
|
c10::complex<double> val3, c10::complex<double> val4) {
|
||||||
|
values = _mm512_setr_pd(val1.real(), val1.imag(),
|
||||||
|
val2.real(), val2.imag(),
|
||||||
|
val3.real(), val3.imag(),
|
||||||
|
val4.real(), val4.imag());
|
||||||
|
}
|
||||||
|
operator __m512d() const {
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
template <int64_t mask>
|
||||||
|
static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b) {
|
||||||
|
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
|
||||||
|
// NOLINTNEXTLINE(clang-diagnostic-warning)
|
||||||
|
switch (mask) {
|
||||||
|
case 0:
|
||||||
|
return a;
|
||||||
|
case 1:
|
||||||
|
return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011
|
||||||
|
case 2:
|
||||||
|
return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100
|
||||||
|
case 3:
|
||||||
|
return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111
|
||||||
|
case 4:
|
||||||
|
return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000
|
||||||
|
case 5:
|
||||||
|
return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011
|
||||||
|
case 6:
|
||||||
|
return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100
|
||||||
|
case 7:
|
||||||
|
return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111
|
||||||
|
case 8:
|
||||||
|
return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000
|
||||||
|
case 9:
|
||||||
|
return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011
|
||||||
|
case 10:
|
||||||
|
return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100
|
||||||
|
case 11:
|
||||||
|
return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111
|
||||||
|
case 12:
|
||||||
|
return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000
|
||||||
|
case 13:
|
||||||
|
return _mm512_mask_blend_pd(0xF3, a.values, b.values); //b0000 1101 = b1111 0011
|
||||||
|
case 14:
|
||||||
|
return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100
|
||||||
|
case 15:
|
||||||
|
return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b,
|
||||||
|
const Vectorized<c10::complex<double>>& mask) {
|
||||||
|
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
|
||||||
|
auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values);
|
||||||
|
auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
|
||||||
|
auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ);
|
||||||
|
return _mm512_mask_blend_pd(mmask, a.values, b.values);
|
||||||
|
}
|
||||||
|
template<typename step_t>
|
||||||
|
static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0.,
|
||||||
|
step_t step = static_cast<step_t>(1)) {
|
||||||
|
return Vectorized<c10::complex<double>>(base,
|
||||||
|
base + c10::complex<double>(1)*step,
|
||||||
|
base + c10::complex<double>(2)*step,
|
||||||
|
base + c10::complex<double>(3)*step);
|
||||||
|
}
|
||||||
|
static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b,
|
||||||
|
int64_t count = size()) {
|
||||||
|
switch (count) {
|
||||||
|
case 0:
|
||||||
|
return a;
|
||||||
|
case 1:
|
||||||
|
return blend<1>(a, b);
|
||||||
|
case 2:
|
||||||
|
return blend<3>(a, b);
|
||||||
|
case 3:
|
||||||
|
return blend<7>(a, b);
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
|
||||||
|
if (count == size())
|
||||||
|
return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
|
||||||
|
|
||||||
|
__at_align__ double tmp_values[2*size()];
|
||||||
|
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
|
||||||
|
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
|
||||||
|
// instructions while a loop would be compiled to one instruction.
|
||||||
|
for (auto i = 0; i < 2*size(); ++i) {
|
||||||
|
tmp_values[i] = 0.0;
|
||||||
|
}
|
||||||
|
std::memcpy(
|
||||||
|
tmp_values,
|
||||||
|
reinterpret_cast<const double*>(ptr),
|
||||||
|
count * sizeof(c10::complex<double>));
|
||||||
|
return _mm512_load_pd(tmp_values);
|
||||||
|
}
|
||||||
|
void store(void* ptr, int count = size()) const {
|
||||||
|
if (count == size()) {
|
||||||
|
_mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
|
||||||
|
} else if (count > 0) {
|
||||||
|
double tmp_values[2*size()];
|
||||||
|
_mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
|
||||||
|
std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const c10::complex<double>& operator[](int idx) const = delete;
|
||||||
|
c10::complex<double>& operator[](int idx) = delete;
|
||||||
|
Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
|
||||||
|
__at_align__ c10::complex<double> tmp[size()];
|
||||||
|
store(tmp);
|
||||||
|
for (int i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = f(tmp[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
// AVX512 doesn't have horizontal add & horizontal sub instructions.
|
||||||
|
// TODO: hadd_pd() & hsub_pd() may have scope for improvement.
|
||||||
|
static inline __m512d hadd_pd(__m512d a, __m512d b) {
|
||||||
|
__m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
|
||||||
|
__m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
|
||||||
|
return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
|
||||||
|
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
|
||||||
|
}
|
||||||
|
static inline __m512d hsub_pd(__m512d a, __m512d b) {
|
||||||
|
__m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
|
||||||
|
__m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
|
||||||
|
return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
|
||||||
|
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
|
||||||
|
}
|
||||||
|
__m512d abs_2_() const {
|
||||||
|
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
|
||||||
|
return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
|
||||||
|
}
|
||||||
|
__m512d abs_() const {
|
||||||
|
return _mm512_sqrt_pd(abs_2_()); // abs abs
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> abs() const {
|
||||||
|
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
|
||||||
|
return _mm512_and_pd(abs_(), real_mask); // abs 0
|
||||||
|
}
|
||||||
|
__m512d angle_() const {
|
||||||
|
//angle = atan2(b/a)
|
||||||
|
auto b_a = _mm512_permute_pd(values, 0x55); // b a
|
||||||
|
return Sleef_atan2d8_u10(values, b_a); // 90-angle angle
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> angle() const {
|
||||||
|
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
|
||||||
|
auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle
|
||||||
|
return _mm512_and_pd(angle, real_mask); // angle 0
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> sgn() const {
|
||||||
|
auto abs = abs_();
|
||||||
|
auto zero = _mm512_setzero_pd();
|
||||||
|
auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
|
||||||
|
auto mask_vec = _mm512_mask_set1_epi64(_mm512_castpd_si512(zero), mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF);
|
||||||
|
auto abs_val = Vectorized(abs);
|
||||||
|
|
||||||
|
auto div = values / abs_val.values; // x / abs(x)
|
||||||
|
|
||||||
|
return blendv(div, zero, _mm512_castsi512_pd(mask_vec));
|
||||||
|
}
|
||||||
|
__m512d real_() const {
|
||||||
|
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
|
||||||
|
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
|
||||||
|
return _mm512_and_pd(values, real_mask);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> real() const {
|
||||||
|
return real_();
|
||||||
|
}
|
||||||
|
__m512d imag_() const {
|
||||||
|
const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
|
||||||
|
0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
|
||||||
|
0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
|
||||||
|
0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
|
||||||
|
return _mm512_and_pd(values, imag_mask);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> imag() const {
|
||||||
|
return _mm512_permute_pd(imag_(), 0x55); //b a
|
||||||
|
}
|
||||||
|
__m512d conj_() const {
|
||||||
|
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
|
||||||
|
return _mm512_xor_pd(values, sign_mask); // a -b
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> conj() const {
|
||||||
|
return conj_();
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> log() const {
|
||||||
|
// Most trigonomic ops use the log() op to improve complex number performance.
|
||||||
|
return map(std::log);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> log2() const {
|
||||||
|
const __m512d log2_ = _mm512_set1_pd(std::log(2));
|
||||||
|
return _mm512_div_pd(log(), log2_);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> log10() const {
|
||||||
|
const __m512d log10_ = _mm512_set1_pd(std::log(10));
|
||||||
|
return _mm512_div_pd(log(), log10_);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> log1p() const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> asin() const {
|
||||||
|
// asin(x)
|
||||||
|
// = -i*ln(iz + sqrt(1 -z^2))
|
||||||
|
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
|
||||||
|
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
|
||||||
|
const __m512d one = _mm512_set1_pd(1);
|
||||||
|
|
||||||
|
auto conj = conj_();
|
||||||
|
auto b_a = _mm512_permute_pd(conj, 0x55); //-b a
|
||||||
|
auto ab = _mm512_mul_pd(conj, b_a); //-ab -ab
|
||||||
|
auto im = _mm512_add_pd(ab, ab); //-2ab -2ab
|
||||||
|
|
||||||
|
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
|
||||||
|
auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b b*b-a*a
|
||||||
|
re = _mm512_sub_pd(one, re);
|
||||||
|
|
||||||
|
auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); //sqrt(re + i*im)
|
||||||
|
auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); //ln(iz + sqrt())
|
||||||
|
return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj(); //-i*ln()
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> acos() const {
|
||||||
|
// acos(x) = pi/2 - asin(x)
|
||||||
|
constexpr auto pi_2d = c10::pi<double> / 2;
|
||||||
|
const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0);
|
||||||
|
return _mm512_sub_pd(pi_2, asin());
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> atan() const;
|
||||||
|
Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>> &b) const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> erf() const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> erfc() const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> exp() const {
|
||||||
|
//exp(a + bi)
|
||||||
|
// = exp(a)*(cos(b) + sin(b)i)
|
||||||
|
auto exp = Sleef_expd8_u10(values); //exp(a) exp(b)
|
||||||
|
exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) exp(a)
|
||||||
|
|
||||||
|
auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
|
||||||
|
auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55),
|
||||||
|
sin_cos.x); //cos(b) sin(b)
|
||||||
|
return _mm512_mul_pd(exp, cos_sin);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> expm1() const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> sin() const {
|
||||||
|
return map(std::sin);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> sinh() const {
|
||||||
|
return map(std::sinh);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> cos() const {
|
||||||
|
return map(std::cos);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> cosh() const {
|
||||||
|
return map(std::cosh);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> ceil() const {
|
||||||
|
return _mm512_ceil_pd(values);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> floor() const {
|
||||||
|
return _mm512_floor_pd(values);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> hypot(const Vectorized<c10::complex<double>> &b) const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> igamma(const Vectorized<c10::complex<double>> &x) const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> igammac(const Vectorized<c10::complex<double>> &x) const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> neg() const {
|
||||||
|
auto zero = _mm512_setzero_pd();
|
||||||
|
return _mm512_sub_pd(zero, values);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> nextafter(const Vectorized<c10::complex<double>> &b) const {
|
||||||
|
AT_ERROR("not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> round() const {
|
||||||
|
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> tan() const {
|
||||||
|
return map(std::tan);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> tanh() const {
|
||||||
|
return map(std::tanh);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> trunc() const {
|
||||||
|
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> sqrt() const {
|
||||||
|
return map(std::sqrt);
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> reciprocal() const;
|
||||||
|
Vectorized<c10::complex<double>> rsqrt() const {
|
||||||
|
return sqrt().reciprocal();
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
|
||||||
|
__at_align__ c10::complex<double> x_tmp[size()];
|
||||||
|
__at_align__ c10::complex<double> y_tmp[size()];
|
||||||
|
store(x_tmp);
|
||||||
|
exp.store(y_tmp);
|
||||||
|
for (int i = 0; i < size(); i++) {
|
||||||
|
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
|
||||||
|
}
|
||||||
|
return loadu(x_tmp);
|
||||||
|
}
|
||||||
|
// Comparison using the _CMP_**_OQ predicate.
|
||||||
|
// `O`: get false if an operand is NaN
|
||||||
|
// `Q`: do not raise if an operand is NaN
|
||||||
|
Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
|
||||||
|
Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
|
||||||
|
Vectorized<c10::complex<double>> lt(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> le(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> gt(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
Vectorized<c10::complex<double>> ge(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
TORCH_CHECK(false, "not supported for complex numbers");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a,
|
||||||
|
const Vectorized<c10::complex<double>> &b) {
|
||||||
|
return _mm512_add_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a,
|
||||||
|
const Vectorized<c10::complex<double>> &b) {
|
||||||
|
return _mm512_sub_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a,
|
||||||
|
const Vectorized<c10::complex<double>> &b) {
|
||||||
|
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||||
|
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
|
||||||
|
auto ac_bd = _mm512_mul_pd(a, b); //ac bd
|
||||||
|
|
||||||
|
auto d_c = _mm512_permute_pd(b, 0x55); //d c
|
||||||
|
d_c = _mm512_xor_pd(sign_mask, d_c); //d -c
|
||||||
|
auto ad_bc = _mm512_mul_pd(a, d_c); //ad -bc
|
||||||
|
|
||||||
|
auto ret = Vectorized<c10::complex<double>>::hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a,
|
||||||
|
const Vectorized<c10::complex<double>> &b) {
|
||||||
|
//re + im*i = (a + bi) / (c + di)
|
||||||
|
//re = (ac + bd)/abs_2()
|
||||||
|
//im = (bc - ad)/abs_2()
|
||||||
|
const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
|
||||||
|
auto ac_bd = _mm512_mul_pd(a, b); //ac bd
|
||||||
|
|
||||||
|
auto d_c = _mm512_permute_pd(b, 0x55); //d c
|
||||||
|
d_c = _mm512_xor_pd(sign_mask, d_c); //-d c
|
||||||
|
auto ad_bc = _mm512_mul_pd(a, d_c); //-ad bc
|
||||||
|
|
||||||
|
auto re_im = Vectorized<c10::complex<double>>::hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad
|
||||||
|
return _mm512_div_pd(re_im, b.abs_2_());
|
||||||
|
}
|
||||||
|
|
||||||
|
// reciprocal. Implement this here so we can use multiplication.
|
||||||
|
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
|
||||||
|
//re + im*i = (a + bi) / (c + di)
|
||||||
|
//re = (ac + bd)/abs_2() = c/abs_2()
|
||||||
|
//im = (bc - ad)/abs_2() = d/abs_2()
|
||||||
|
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
|
||||||
|
auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
|
||||||
|
return _mm512_div_pd(c_d, abs_2_());
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
|
||||||
|
// atan(x) = i/2 * ln((i + z)/(i - z))
|
||||||
|
const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
|
||||||
|
const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
|
||||||
|
|
||||||
|
auto sum = Vectorized(_mm512_add_pd(i, values)); // a 1+b
|
||||||
|
auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b
|
||||||
|
auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
|
||||||
|
return i_half*ln; // i/2*ln()
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b) {
|
||||||
|
auto zero_vec = _mm512_set1_epi64(0);
|
||||||
|
auto abs_a = a.abs_2_();
|
||||||
|
auto abs_b = b.abs_2_();
|
||||||
|
auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ);
|
||||||
|
auto max = _mm512_mask_blend_pd(mask, a, b);
|
||||||
|
// Exploit the fact that all-ones is a NaN.
|
||||||
|
auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
|
||||||
|
auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF);
|
||||||
|
return _mm512_or_pd(max, _mm512_castsi512_pd(isnan));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b) {
|
||||||
|
auto zero_vec = _mm512_set1_epi64(0);
|
||||||
|
auto abs_a = a.abs_2_();
|
||||||
|
auto abs_b = b.abs_2_();
|
||||||
|
auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ);
|
||||||
|
auto min = _mm512_mask_blend_pd(mask, a, b);
|
||||||
|
// Exploit the fact that all-ones is a NaN.
|
||||||
|
auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
|
||||||
|
auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF);
|
||||||
|
return _mm512_or_pd(min, _mm512_castsi512_pd(isnan));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b) {
|
||||||
|
return _mm512_and_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b) {
|
||||||
|
return _mm512_or_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a,
|
||||||
|
const Vectorized<c10::complex<double>>& b) {
|
||||||
|
return _mm512_xor_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
return (*this == other) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
|
||||||
|
return (*this != other) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}}}
|
||||||
1030
aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
Normal file
1030
aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
Normal file
File diff suppressed because it is too large
Load Diff
454
aten/src/ATen/cpu/vec/vec512/vec512_double.h
Normal file
454
aten/src/ATen/cpu/vec/vec512/vec512_double.h
Normal file
|
|
@ -0,0 +1,454 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
||||||
|
// See Note [Do not compile initializers with AVX]
|
||||||
|
|
||||||
|
#include <ATen/cpu/vec/intrinsics.h>
|
||||||
|
#include <ATen/cpu/vec/vec_base.h>
|
||||||
|
#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
|
||||||
|
#include <sleef.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace at {
|
||||||
|
namespace vec {
|
||||||
|
// See Note [Acceptable use of anonymous namespace in header]
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
|
||||||
|
template <> class Vectorized<double> {
|
||||||
|
private:
|
||||||
|
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
|
||||||
|
public:
|
||||||
|
// values needs to be public for compilation with clang
|
||||||
|
// as vec512.h uses it
|
||||||
|
__m512d values;
|
||||||
|
using value_type = double;
|
||||||
|
using size_type = int;
|
||||||
|
static constexpr size_type size() {
|
||||||
|
return 8;
|
||||||
|
}
|
||||||
|
Vectorized() {}
|
||||||
|
Vectorized(__m512d v) : values(v) {}
|
||||||
|
Vectorized(double val) {
|
||||||
|
values = _mm512_set1_pd(val);
|
||||||
|
}
|
||||||
|
Vectorized(double val1, double val2, double val3, double val4,
|
||||||
|
double val5, double val6, double val7, double val8) {
|
||||||
|
values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8);
|
||||||
|
}
|
||||||
|
operator __m512d() const {
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
template <int64_t mask>
|
||||||
|
static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_mask_blend_pd(mask, a.values, b.values);
|
||||||
|
}
|
||||||
|
static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
|
||||||
|
const Vectorized<double>& mask) {
|
||||||
|
auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
|
||||||
|
auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ);
|
||||||
|
return _mm512_mask_blend_pd(mmask, a.values, b.values);
|
||||||
|
}
|
||||||
|
template<typename step_t>
|
||||||
|
static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
|
||||||
|
return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step,
|
||||||
|
base + 4 * step, base + 5 * step, base + 6 * step,
|
||||||
|
base + 7 * step);
|
||||||
|
}
|
||||||
|
static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
|
||||||
|
int64_t count = size()) {
|
||||||
|
switch (count) {
|
||||||
|
case 0:
|
||||||
|
return a;
|
||||||
|
case 1:
|
||||||
|
return blend<1>(a, b);
|
||||||
|
case 2:
|
||||||
|
return blend<3>(a, b);
|
||||||
|
case 3:
|
||||||
|
return blend<7>(a, b);
|
||||||
|
case 4:
|
||||||
|
return blend<15>(a, b);
|
||||||
|
case 5:
|
||||||
|
return blend<31>(a, b);
|
||||||
|
case 6:
|
||||||
|
return blend<63>(a, b);
|
||||||
|
case 7:
|
||||||
|
return blend<127>(a, b);
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
|
||||||
|
if (count == size())
|
||||||
|
return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
|
||||||
|
|
||||||
|
|
||||||
|
__at_align__ double tmp_values[size()];
|
||||||
|
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
|
||||||
|
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
|
||||||
|
// instructions while a loop would be compiled to one instruction.
|
||||||
|
for (auto i = 0; i < size(); ++i) {
|
||||||
|
tmp_values[i] = 0.0;
|
||||||
|
}
|
||||||
|
std::memcpy(
|
||||||
|
tmp_values,
|
||||||
|
reinterpret_cast<const double*>(ptr),
|
||||||
|
count * sizeof(double));
|
||||||
|
return _mm512_load_pd(tmp_values);
|
||||||
|
}
|
||||||
|
void store(void* ptr, int count = size()) const {
|
||||||
|
if (count == size()) {
|
||||||
|
_mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
|
||||||
|
} else if (count > 0) {
|
||||||
|
double tmp_values[size()];
|
||||||
|
_mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
|
||||||
|
std::memcpy(ptr, tmp_values, count * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const double& operator[](int idx) const = delete;
|
||||||
|
double& operator[](int idx) = delete;
|
||||||
|
int zero_mask() const {
|
||||||
|
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
|
||||||
|
__mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ);
|
||||||
|
return static_cast<int32_t>(cmp);
|
||||||
|
}
|
||||||
|
Vectorized<double> isnan() const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
Vectorized<double> map(double (*const f)(double)) const {
|
||||||
|
__at_align__ double tmp[size()];
|
||||||
|
store(tmp);
|
||||||
|
for (int64_t i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = f(tmp[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
Vectorized<double> abs() const {
|
||||||
|
auto mask = _mm512_set1_pd(-0.f);
|
||||||
|
return _mm512_andnot_pd(mask, values);
|
||||||
|
}
|
||||||
|
Vectorized<double> angle() const {
|
||||||
|
const auto zero_vec = _mm512_castsi512_pd(zero_vector);
|
||||||
|
const auto nan_vec = _mm512_set1_pd(NAN);
|
||||||
|
const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ);
|
||||||
|
const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF);
|
||||||
|
const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan),
|
||||||
|
zero_vec, _CMP_EQ_OQ);
|
||||||
|
const auto pi = _mm512_set1_pd(c10::pi<double>);
|
||||||
|
|
||||||
|
const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ);
|
||||||
|
auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi);
|
||||||
|
angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec);
|
||||||
|
return angle;
|
||||||
|
}
|
||||||
|
Vectorized<double> real() const {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Vectorized<double> imag() const {
|
||||||
|
return _mm512_set1_pd(0);
|
||||||
|
}
|
||||||
|
Vectorized<double> conj() const {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Vectorized<double> acos() const {
|
||||||
|
return Vectorized<double>(Sleef_acosd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> asin() const {
|
||||||
|
return Vectorized<double>(Sleef_asind8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> atan() const {
|
||||||
|
return Vectorized<double>(Sleef_atand8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> atan2(const Vectorized<double> &b) const {
|
||||||
|
return Vectorized<double>(Sleef_atan2d8_u10(values, b));
|
||||||
|
}
|
||||||
|
Vectorized<double> copysign(const Vectorized<double> &sign) const {
|
||||||
|
return Vectorized<double>(Sleef_copysignd8(values, sign));
|
||||||
|
}
|
||||||
|
Vectorized<double> erf() const {
|
||||||
|
return Vectorized<double>(Sleef_erfd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> erfc() const {
|
||||||
|
return Vectorized<double>(Sleef_erfcd8_u15(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> erfinv() const {
|
||||||
|
return map(calc_erfinv);
|
||||||
|
}
|
||||||
|
Vectorized<double> exp() const {
|
||||||
|
return Vectorized<double>(Sleef_expd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> expm1() const {
|
||||||
|
return Vectorized<double>(Sleef_expm1d8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> fmod(const Vectorized<double>& q) const {
|
||||||
|
return Vectorized<double>(Sleef_fmodd8(values, q));
|
||||||
|
}
|
||||||
|
Vectorized<double> hypot(const Vectorized<double> &b) const {
|
||||||
|
return Vectorized<double>(Sleef_hypotd8_u05(values, b));
|
||||||
|
}
|
||||||
|
Vectorized<double> i0() const {
|
||||||
|
return map(calc_i0);
|
||||||
|
}
|
||||||
|
Vectorized<double> i0e() const {
|
||||||
|
return map(calc_i0e);
|
||||||
|
}
|
||||||
|
Vectorized<double> igamma(const Vectorized<double> &x) const {
|
||||||
|
__at_align__ double tmp[size()];
|
||||||
|
__at_align__ double tmp_x[size()];
|
||||||
|
store(tmp);
|
||||||
|
x.store(tmp_x);
|
||||||
|
for (int64_t i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
Vectorized<double> igammac(const Vectorized<double> &x) const {
|
||||||
|
__at_align__ double tmp[size()];
|
||||||
|
__at_align__ double tmp_x[size()];
|
||||||
|
store(tmp);
|
||||||
|
x.store(tmp_x);
|
||||||
|
for (int64_t i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
Vectorized<double> log() const {
|
||||||
|
return Vectorized<double>(Sleef_logd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> log2() const {
|
||||||
|
return Vectorized<double>(Sleef_log2d8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> log10() const {
|
||||||
|
return Vectorized<double>(Sleef_log10d8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> log1p() const {
|
||||||
|
return Vectorized<double>(Sleef_log1pd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> sin() const {
|
||||||
|
return Vectorized<double>(Sleef_sind8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> sinh() const {
|
||||||
|
return Vectorized<double>(Sleef_sinhd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> cos() const {
|
||||||
|
return Vectorized<double>(Sleef_cosd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> cosh() const {
|
||||||
|
return Vectorized<double>(Sleef_coshd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> ceil() const {
|
||||||
|
return _mm512_ceil_pd(values);
|
||||||
|
}
|
||||||
|
Vectorized<double> floor() const {
|
||||||
|
return _mm512_floor_pd(values);
|
||||||
|
}
|
||||||
|
Vectorized<double> frac() const;
|
||||||
|
Vectorized<double> neg() const {
|
||||||
|
return _mm512_xor_pd(_mm512_set1_pd(-0.), values);
|
||||||
|
}
|
||||||
|
Vectorized<double> nextafter(const Vectorized<double> &b) const {
|
||||||
|
return Vectorized<double>(Sleef_nextafterd8(values, b));
|
||||||
|
}
|
||||||
|
Vectorized<double> round() const {
|
||||||
|
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||||
|
}
|
||||||
|
Vectorized<double> tan() const {
|
||||||
|
return Vectorized<double>(Sleef_tand8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> tanh() const {
|
||||||
|
return Vectorized<double>(Sleef_tanhd8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> trunc() const {
|
||||||
|
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
|
||||||
|
}
|
||||||
|
Vectorized<double> lgamma() const {
|
||||||
|
return Vectorized<double>(Sleef_lgammad8_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> sqrt() const {
|
||||||
|
return _mm512_sqrt_pd(values);
|
||||||
|
}
|
||||||
|
Vectorized<double> reciprocal() const {
|
||||||
|
return _mm512_div_pd(_mm512_set1_pd(1), values);
|
||||||
|
}
|
||||||
|
Vectorized<double> rsqrt() const {
|
||||||
|
return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values));
|
||||||
|
}
|
||||||
|
Vectorized<double> pow(const Vectorized<double> &b) const {
|
||||||
|
return Vectorized<double>(Sleef_powd8_u10(values, b));
|
||||||
|
}
|
||||||
|
// Comparison using the _CMP_**_OQ predicate.
|
||||||
|
// `O`: get false if an operand is NaN
|
||||||
|
// `Q`: do not raise if an operand is NaN
|
||||||
|
Vectorized<double> operator==(const Vectorized<double>& other) const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> operator!=(const Vectorized<double>& other) const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> operator<(const Vectorized<double>& other) const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> operator<=(const Vectorized<double>& other) const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> operator>(const Vectorized<double>& other) const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> operator>=(const Vectorized<double>& other) const {
|
||||||
|
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ);
|
||||||
|
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> eq(const Vectorized<double>& other) const;
|
||||||
|
Vectorized<double> ne(const Vectorized<double>& other) const;
|
||||||
|
Vectorized<double> lt(const Vectorized<double>& other) const;
|
||||||
|
Vectorized<double> le(const Vectorized<double>& other) const;
|
||||||
|
Vectorized<double> gt(const Vectorized<double>& other) const;
|
||||||
|
Vectorized<double> ge(const Vectorized<double>& other) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_add_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_sub_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_mul_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_div_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// frac. Implement this here so we can use subtraction.
|
||||||
|
Vectorized<double> Vectorized<double>::frac() const {
|
||||||
|
return *this - this->trunc();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
||||||
|
// either input is a NaN.
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
auto zero_vec = _mm512_set1_epi64(0);
|
||||||
|
Vectorized<double> max = _mm512_max_pd(a, b);
|
||||||
|
auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
|
||||||
|
auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
// Exploit the fact that all-ones is a NaN.
|
||||||
|
return _mm512_or_pd(max, isnan);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
||||||
|
// either input is a NaN.
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
auto zero_vec = _mm512_set1_epi64(0);
|
||||||
|
Vectorized<double> min = _mm512_min_pd(a, b);
|
||||||
|
auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
|
||||||
|
auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
|
||||||
|
0xFFFFFFFFFFFFFFFF));
|
||||||
|
// Exploit the fact that all-ones is a NaN.
|
||||||
|
return _mm512_or_pd(min, isnan);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
|
||||||
|
return _mm512_min_pd(max, _mm512_max_pd(min, a));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
|
||||||
|
return _mm512_max_pd(min, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
|
||||||
|
return _mm512_min_pd(max, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_and_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_or_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
|
||||||
|
return _mm512_xor_pd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
|
||||||
|
return (*this == other) & Vectorized<double>(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
|
||||||
|
return (*this != other) & Vectorized<double>(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
|
||||||
|
return (*this > other) & Vectorized<double>(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
|
||||||
|
return (*this >= other) & Vectorized<double>(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
|
||||||
|
return (*this < other) & Vectorized<double>(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
|
||||||
|
return (*this <= other) & Vectorized<double>(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline void convert(const double* src, double* dst, int64_t n) {
|
||||||
|
int64_t i;
|
||||||
|
#pragma unroll
|
||||||
|
for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
|
||||||
|
_mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i));
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (; i < n; i++) {
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
|
||||||
|
return _mm512_fmadd_pd(a, b, c);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}}}
|
||||||
469
aten/src/ATen/cpu/vec/vec512/vec512_float.h
Normal file
469
aten/src/ATen/cpu/vec/vec512/vec512_float.h
Normal file
|
|
@ -0,0 +1,469 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
||||||
|
// See Note [Do not compile initializers with AVX]
|
||||||
|
|
||||||
|
#include <ATen/cpu/vec/intrinsics.h>
|
||||||
|
#include <ATen/cpu/vec/vec_base.h>
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
#include <sleef.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace at {
|
||||||
|
namespace vec {
|
||||||
|
// See Note [Acceptable use of anonymous namespace in header]
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||||
|
|
||||||
|
template <> class Vectorized<float> {
|
||||||
|
private:
|
||||||
|
static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0};
|
||||||
|
public:
|
||||||
|
__m512 values;
|
||||||
|
using value_type = float;
|
||||||
|
using size_type = int;
|
||||||
|
static constexpr size_type size() {
|
||||||
|
return 16;
|
||||||
|
}
|
||||||
|
Vectorized() {}
|
||||||
|
Vectorized(__m512 v) : values(v) {}
|
||||||
|
Vectorized(float val) {
|
||||||
|
values = _mm512_set1_ps(val);
|
||||||
|
}
|
||||||
|
Vectorized(float val1, float val2, float val3, float val4,
|
||||||
|
float val5, float val6, float val7, float val8,
|
||||||
|
float val9, float val10, float val11, float val12,
|
||||||
|
float val13, float val14, float val15, float val16) {
|
||||||
|
values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8,
|
||||||
|
val9, val10, val11, val12, val13, val14, val15, val16);
|
||||||
|
}
|
||||||
|
operator __m512() const {
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
template <int64_t mask>
|
||||||
|
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||||
|
return _mm512_mask_blend_ps(mask, a.values, b.values);
|
||||||
|
}
|
||||||
|
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
|
||||||
|
const Vectorized<float>& mask) {
|
||||||
|
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
|
||||||
|
auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ);
|
||||||
|
return _mm512_mask_blend_ps(mmask, a.values, b.values);
|
||||||
|
}
|
||||||
|
template<typename step_t>
|
||||||
|
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
|
||||||
|
return Vectorized<float>(
|
||||||
|
base, base + step, base + 2 * step, base + 3 * step,
|
||||||
|
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
|
||||||
|
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
|
||||||
|
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
|
||||||
|
}
|
||||||
|
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
|
||||||
|
int64_t count = size()) {
|
||||||
|
switch (count) {
|
||||||
|
case 0:
|
||||||
|
return a;
|
||||||
|
case 1:
|
||||||
|
return blend<1>(a, b);
|
||||||
|
case 2:
|
||||||
|
return blend<3>(a, b);
|
||||||
|
case 3:
|
||||||
|
return blend<7>(a, b);
|
||||||
|
case 4:
|
||||||
|
return blend<15>(a, b);
|
||||||
|
case 5:
|
||||||
|
return blend<31>(a, b);
|
||||||
|
case 6:
|
||||||
|
return blend<63>(a, b);
|
||||||
|
case 7:
|
||||||
|
return blend<127>(a, b);
|
||||||
|
case 8:
|
||||||
|
return blend<255>(a, b);
|
||||||
|
case 9:
|
||||||
|
return blend<511>(a, b);
|
||||||
|
case 10:
|
||||||
|
return blend<1023>(a, b);
|
||||||
|
case 11:
|
||||||
|
return blend<2047>(a, b);
|
||||||
|
case 12:
|
||||||
|
return blend<4095>(a, b);
|
||||||
|
case 13:
|
||||||
|
return blend<8191>(a, b);
|
||||||
|
case 14:
|
||||||
|
return blend<16383>(a, b);
|
||||||
|
case 15:
|
||||||
|
return blend<32767>(a, b);
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
|
||||||
|
if (count == size())
|
||||||
|
return _mm512_loadu_ps(reinterpret_cast<const float*>(ptr));
|
||||||
|
__at_align__ float tmp_values[size()];
|
||||||
|
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
|
||||||
|
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
|
||||||
|
// instructions while a loop would be compiled to one instruction.
|
||||||
|
for (auto i = 0; i < size(); ++i) {
|
||||||
|
tmp_values[i] = 0.0;
|
||||||
|
}
|
||||||
|
std::memcpy(
|
||||||
|
tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
|
||||||
|
return _mm512_loadu_ps(tmp_values);
|
||||||
|
}
|
||||||
|
void store(void* ptr, int64_t count = size()) const {
|
||||||
|
if (count == size()) {
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(ptr), values);
|
||||||
|
} else if (count > 0) {
|
||||||
|
float tmp_values[size()];
|
||||||
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
|
||||||
|
std::memcpy(ptr, tmp_values, count * sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const float& operator[](int idx) const = delete;
|
||||||
|
float& operator[](int idx) = delete;
|
||||||
|
int zero_mask() const {
|
||||||
|
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
|
||||||
|
__mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ);
|
||||||
|
return static_cast<int32_t>(cmp);
|
||||||
|
}
|
||||||
|
Vectorized<float> isnan() const {
|
||||||
|
auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q);
|
||||||
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
|
||||||
|
0xFFFFFFFF));
|
||||||
|
}
|
||||||
|
Vectorized<float> map(float (*const f)(float)) const {
|
||||||
|
__at_align__ float tmp[size()];
|
||||||
|
store(tmp);
|
||||||
|
for (int64_t i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = f(tmp[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
Vectorized<float> abs() const {
|
||||||
|
auto mask = _mm512_set1_ps(-0.f);
|
||||||
|
return _mm512_andnot_ps(mask, values);
|
||||||
|
}
|
||||||
|
Vectorized<float> angle() const {
|
||||||
|
__m512 zero_vec = _mm512_set1_ps(0.f);
|
||||||
|
const auto nan_vec = _mm512_set1_ps(NAN);
|
||||||
|
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
|
||||||
|
const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
|
||||||
|
not_nan_mask, 0xFFFFFFFF);
|
||||||
|
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec),
|
||||||
|
zero_vec, _CMP_EQ_OQ);
|
||||||
|
const auto pi = _mm512_set1_ps(c10::pi<double>);
|
||||||
|
|
||||||
|
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
|
||||||
|
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
|
||||||
|
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
|
||||||
|
return angle;
|
||||||
|
}
|
||||||
|
Vectorized<float> real() const {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Vectorized<float> imag() const {
|
||||||
|
return _mm512_set1_ps(0);
|
||||||
|
}
|
||||||
|
Vectorized<float> conj() const {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Vectorized<float> acos() const {
|
||||||
|
return Vectorized<float>(Sleef_acosf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> asin() const {
|
||||||
|
return Vectorized<float>(Sleef_asinf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> atan() const {
|
||||||
|
return Vectorized<float>(Sleef_atanf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> atan2(const Vectorized<float> &b) const {
|
||||||
|
return Vectorized<float>(Sleef_atan2f16_u10(values, b));
|
||||||
|
}
|
||||||
|
Vectorized<float> copysign(const Vectorized<float> &sign) const {
|
||||||
|
return Vectorized<float>(Sleef_copysignf16(values, sign));
|
||||||
|
}
|
||||||
|
Vectorized<float> erf() const {
|
||||||
|
return Vectorized<float>(Sleef_erff16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> erfc() const {
|
||||||
|
return Vectorized<float>(Sleef_erfcf16_u15(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> erfinv() const {
|
||||||
|
return map(calc_erfinv);
|
||||||
|
}
|
||||||
|
Vectorized<float> exp() const {
|
||||||
|
return Vectorized<float>(Sleef_expf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> expm1() const {
|
||||||
|
return Vectorized<float>(Sleef_expm1f16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> fmod(const Vectorized<float>& q) const {
|
||||||
|
return Vectorized<float>(Sleef_fmodf16(values, q));
|
||||||
|
}
|
||||||
|
Vectorized<float> log() const {
|
||||||
|
return Vectorized<float>(Sleef_logf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> log2() const {
|
||||||
|
return Vectorized<float>(Sleef_log2f16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> log10() const {
|
||||||
|
return Vectorized<float>(Sleef_log10f16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> log1p() const {
|
||||||
|
return Vectorized<float>(Sleef_log1pf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> frac() const;
|
||||||
|
Vectorized<float> sin() const {
|
||||||
|
return Vectorized<float>(Sleef_sinf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> sinh() const {
|
||||||
|
return Vectorized<float>(Sleef_sinhf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> cos() const {
|
||||||
|
return Vectorized<float>(Sleef_cosf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> cosh() const {
|
||||||
|
return Vectorized<float>(Sleef_coshf16_u10(values));
|
||||||
|
}
|
||||||
|
Vectorized<float> ceil() const {
|
||||||
|
return _mm512_ceil_ps(values);
|
||||||
|
}
|
||||||
|
Vectorized<float> floor() const {
|
||||||
|
return _mm512_floor_ps(values);
|
||||||
|
}
|
||||||
|
Vectorized<float> hypot(const Vectorized<float> &b) const {
|
||||||
|
return Vectorized<float>(Sleef_hypotf16_u05(values, b));
|
||||||
|
}
|
||||||
|
Vectorized<float> i0() const {
|
||||||
|
return map(calc_i0);
|
||||||
|
}
|
||||||
|
Vectorized<float> i0e() const {
|
||||||
|
return map(calc_i0e);
|
||||||
|
}
|
||||||
|
Vectorized<float> igamma(const Vectorized<float> &x) const {
|
||||||
|
__at_align__ float tmp[size()];
|
||||||
|
__at_align__ float tmp_x[size()];
|
||||||
|
store(tmp);
|
||||||
|
x.store(tmp_x);
|
||||||
|
for (int64_t i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
Vectorized<float> igammac(const Vectorized<float> &x) const {
|
||||||
|
__at_align__ float tmp[size()];
|
||||||
|
__at_align__ float tmp_x[size()];
|
||||||
|
store(tmp);
|
||||||
|
x.store(tmp_x);
|
||||||
|
for (int64_t i = 0; i < size(); i++) {
|
||||||
|
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
|
||||||
|
}
|
||||||
|
return loadu(tmp);
|
||||||
|
}
|
||||||
|
Vectorized<float> neg() const {
|
||||||
|
return _mm512_xor_ps(_mm512_set1_ps(-0.f), values);
|
||||||
|
}
|
||||||
|
Vectorized<float> nextafter(const Vectorized<float> &b) const {
|
||||||
|
return Vectorized<float>(Sleef_nextafterf16(values, b));
|
||||||
|
}
|
||||||
|
Vectorized<float> round() const {
|
||||||
|
    return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<float> tan() const {
    return Vectorized<float>(Sleef_tanf16_u10(values));
  }
  Vectorized<float> tanh() const {
    return Vectorized<float>(Sleef_tanhf16_u10(values));
  }
  Vectorized<float> trunc() const {
    return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<float> lgamma() const {
    return Vectorized<float>(Sleef_lgammaf16_u10(values));
  }
  Vectorized<float> sqrt() const {
    return _mm512_sqrt_ps(values);
  }
  Vectorized<float> reciprocal() const {
    return _mm512_div_ps(_mm512_set1_ps(1), values);
  }
  Vectorized<float> rsqrt() const {
    return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values));
  }
  Vectorized<float> pow(const Vectorized<float> &b) const {
    return Vectorized<float>(Sleef_powf16_u10(values, b));
  }
  // Comparison using the _CMP_**_OQ predicate.
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ);
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }

  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ);
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }

  Vectorized<float> operator<(const Vectorized<float>& other) const {
    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ);
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }

  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ);
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }

  Vectorized<float> operator>(const Vectorized<float>& other) const {
    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ);
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }

  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ);
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }

  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_add_ps(a, b);
}

template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_sub_ps(a, b);
}

template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_mul_ps(a, b);
}

template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_div_ps(a, b);
}

// frac. Implement this here so we can use subtraction
Vectorized<float> Vectorized<float>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  auto zero_vec = _mm512_set1_epi32(0);
  auto max = _mm512_max_ps(a, b);
  auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
  auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
                                                          0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  return _mm512_or_ps(max, isnan);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
  auto zero_vec = _mm512_set1_epi32(0);
  auto min = _mm512_min_ps(a, b);
  auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
  auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
                                                          0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  return _mm512_or_ps(min, isnan);
}

template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
  return _mm512_min_ps(max, _mm512_max_ps(min, a));
}

template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
  return _mm512_min_ps(max, a);
}

template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
  return _mm512_max_ps(min, a);
}

template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_and_ps(a, b);
}

template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_or_ps(a, b);
}

template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm512_xor_ps(a, b);
}

Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}

Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}

template <>
inline void convert(const float* src, float* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  return _mm512_fmadd_ps(a, b, c);
}

#endif

}}}
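A note on the comparison operators above: unlike AVX2, where _mm256_cmp_ps already yields a full 256-bit vector mask, an AVX512 comparison returns a 16-bit kmask, so the kernel rebuilds an all-ones-per-lane float vector from it with _mm512_mask_set1_epi32. The snippet below is a minimal standalone sketch of that trick, not ATen code; it assumes an AVX512F-capable CPU and a compiler flag such as -mavx512f.

// Minimal standalone sketch (not ATen code): expand an AVX512 compare kmask
// into an all-ones-per-lane float vector, the same trick operator== uses above.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  alignas(64) float a[16], b[16];
  for (int i = 0; i < 16; i++) {
    a[i] = static_cast<float>(i);
    b[i] = (i % 2 == 0) ? static_cast<float>(i) : -1.0f;  // even lanes match, odd lanes differ
  }
  __m512 va = _mm512_load_ps(a);
  __m512 vb = _mm512_load_ps(b);
  // AVX512 comparisons yield a 16-bit kmask, one bit per float lane.
  __mmask16 mask = _mm512_cmp_ps_mask(va, vb, _CMP_EQ_OQ);
  // Broadcast 0xFFFFFFFF into the lanes selected by the mask, zero elsewhere,
  // then reinterpret the integer vector as floats (an all-ones lane is a NaN bit pattern).
  __m512i zero = _mm512_set1_epi32(0);
  __m512 eqvec = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero, mask, 0xFFFFFFFF));
  alignas(64) uint32_t out[16];
  _mm512_store_si512(reinterpret_cast<__m512i*>(out), _mm512_castps_si512(eqvec));
  for (int i = 0; i < 16; i++) {
    std::printf("lane %2d: %s\n", i, out[i] == 0xFFFFFFFFu ? "equal" : "not equal");
  }
  return 0;
}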
aten/src/ATen/cpu/vec/vec512/vec512_int.h (new file, 1173 lines): file diff suppressed because it is too large.
aten/src/ATen/cpu/vec/vec512/vec512_qint.h (new file, 1195 lines): file diff suppressed because it is too large.
@@ -20,7 +20,7 @@
 #include <type_traits>
 #include <bitset>
 
-#include <ATen/cpu/vec/vec256/intrinsics.h>
+#include <ATen/cpu/vec/intrinsics.h>
 #include <ATen/native/Math.h>
 #include <ATen/NumericUtils.h>
 #include <c10/util/C++17.h>
@@ -32,13 +32,28 @@
 #include <c10/util/TypeCast.h>
 #include <c10/macros/Macros.h>
 
+// These macros helped us unify vec_base.h
+#ifdef CPU_CAPABILITY_AVX512
 #if defined(__GNUC__)
-#define __at_align32__ __attribute__((aligned(32)))
+#define __at_align__ __attribute__((aligned(64)))
 #elif defined(_WIN32)
-#define __at_align32__ __declspec(align(32))
+#define __at_align__ __declspec(align(64))
 #else
-#define __at_align32__
+#define __at_align__
 #endif
+#define VECTOR_WIDTH 64
+#define int_vector __m512i
+#else // CPU_CAPABILITY_AVX512
+#if defined(__GNUC__)
+#define __at_align__ __attribute__((aligned(32)))
+#elif defined(_WIN32)
+#define __at_align__ __declspec(align(32))
+#else
+#define __at_align__
+#endif
+#define VECTOR_WIDTH 32
+#define int_vector __m256i
+#endif // CPU_CAPABILITY_AVX512
 
 namespace at {
 namespace vec {
@@ -70,11 +85,11 @@ using int_same_size_t = typename int_of_size<sizeof(T)>::type;
 
 // NOTE: If you specialize on a type, you must define all operations!
 
-// emulates vectorized types
+// emulates Vectorized types
 template <class T>
 struct Vectorized {
 private:
-  __at_align32__ T values[32 / sizeof(T)];
+  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
 public:
   using value_type = T;
   using size_type = int;
@@ -111,7 +126,7 @@ public:
   // identifier is odr-used or not, and in any case it's hard to tell if
   // a variable is odr-used or not. So best to just cut the problem at the root.
   static constexpr size_type size() {
-    return 32 / sizeof(T);
+    return VECTOR_WIDTH / sizeof(T);
   }
   Vectorized() : values{0} {}
   Vectorized(T val) {
@@ -134,60 +149,60 @@ public:
   template <int64_t mask_>
   static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
     int64_t mask = mask_;
-    Vectorized vec;
+    Vectorized vector;
     for (int64_t i = 0; i < size(); i++) {
       if (mask & 0x01) {
-        vec[i] = b[i];
+        vector[i] = b[i];
       } else {
-        vec[i] = a[i];
+        vector[i] = a[i];
       }
       mask = mask >> 1;
     }
-    return vec;
+    return vector;
   }
   static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                               const Vectorized<T>& mask) {
-    Vectorized vec;
+    Vectorized vector;
     int_same_size_t<T> buffer[size()];
     mask.store(buffer);
     for (int64_t i = 0; i < size(); i++) {
       if (buffer[i] & 0x01)
       {
-        vec[i] = b[i];
+        vector[i] = b[i];
       } else {
-        vec[i] = a[i];
+        vector[i] = a[i];
       }
     }
-    return vec;
+    return vector;
   }
   template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
   static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
-    Vectorized vec;
+    Vectorized vector;
     for (int64_t i = 0; i < size(); i++) {
-      vec.values[i] = base + i * step;
+      vector.values[i] = base + i * step;
     }
-    return vec;
+    return vector;
   }
   static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
-    Vectorized vec;
+    Vectorized vector;
     for (int64_t i = 0; i < size(); i++) {
       if (i < count) {
-        vec[i] = b[i];
+        vector[i] = b[i];
       } else {
-        vec[i] = a[i];
+        vector[i] = a[i];
       }
     }
-    return vec;
+    return vector;
   }
   static Vectorized<T> loadu(const void* ptr) {
-    Vectorized vec;
-    std::memcpy(vec.values, ptr, 32);
-    return vec;
+    Vectorized vector;
+    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
+    return vector;
   }
   static Vectorized<T> loadu(const void* ptr, int64_t count) {
-    Vectorized vec;
-    std::memcpy(vec.values, ptr, count * sizeof(T));
-    return vec;
+    Vectorized vector;
+    std::memcpy(vector.values, ptr, count * sizeof(T));
+    return vector;
   }
   void store(void* ptr, int count = size()) const {
     std::memcpy(ptr, values, count * sizeof(T));
@@ -203,15 +218,15 @@ public:
     return mask;
   }
   Vectorized<T> isnan() const {
-    Vectorized<T> vec;
+    Vectorized<T> vector;
     for (int64_t i = 0; i != size(); i++) {
       if (_isnan(values[i])) {
-        std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T));
+        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
       } else {
-        std::memset(static_cast<void*>(vec.values + i), 0, sizeof(T));
+        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
       }
     }
-    return vec;
+    return vector;
   }
   Vectorized<T> map(T (*const f)(T)) const {
     Vectorized<T> ret;
@@ -488,15 +503,15 @@ private:
   template <typename Op>
   inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
     // All bits are set to 1 if the pred is true, otherwise 0.
-    Vectorized<T> vec;
+    Vectorized<T> vector;
     for (int64_t i = 0; i != size(); i++) {
       if (op(values[i], other.values[i])) {
-        std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T));
+        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
       } else {
-        std::memset(static_cast<void*>(vec.values + i), 0, sizeof(T));
+        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
-    return vec;
+    return vector;
  }
 
 public:
@@ -511,11 +526,11 @@ private:
   template <typename Op>
   inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
     // 1 if the pred is true, otherwise 0.
-    Vectorized<T> vec;
+    Vectorized<T> vector;
     for (int i = 0; i != size(); ++ i) {
-      vec[i] = bool(op(values[i], other.values[i]));
+      vector[i] = bool(op(values[i], other.values[i]));
     }
-    return vec;
+    return vector;
   }
 
 public:
@@ -668,41 +683,62 @@ Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_
 
 struct Vectorizedi;
 
-#ifdef CPU_CAPABILITY_AVX2
+#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
 
 template <class T, typename Op>
 static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
-  __m256i buffer;
-  __m256i a_buffer = _mm256_loadu_si256(reinterpret_cast<const __m256i*>((const T*)a));
-  __m256i b_buffer = _mm256_loadu_si256(reinterpret_cast<const __m256i*>((const T*)b));
+  int_vector buffer;
+#if defined(CPU_CAPABILITY_AVX2)
+  int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
+  int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
+#elif defined(CPU_CAPABILITY_AVX512)
+  int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
+  int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
+#endif
  buffer = op(a_buffer, b_buffer);
-  __at_align32__ T results[Vectorized<T>::size()];
-  _mm256_storeu_si256(reinterpret_cast<__m256i*>(results), buffer);
+  __at_align__ T results[Vectorized<T>::size()];
+#if defined(CPU_CAPABILITY_AVX2)
+  _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
+#elif defined(CPU_CAPABILITY_AVX512)
+  _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
+#endif
  return Vectorized<T>::loadu(results);
 }
 
 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
 inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
-  // We enclose _mm256_and_si256 with lambda because it is always_inline
-  return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_and_si256(a, b); });
+  // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
+#if defined(CPU_CAPABILITY_AVX2)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
+#elif defined(CPU_CAPABILITY_AVX512)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
+#endif
 }
 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
 inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
-  // We enclose _mm256_or_si256 with lambda because it is always_inline
-  return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_or_si256(a, b); });
+  // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
+#if defined(CPU_CAPABILITY_AVX2)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
+#elif defined(CPU_CAPABILITY_AVX512)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
+#endif
 }
 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
 inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
-  // We enclose _mm256_xor_si256 with lambda because it is always_inline
-  return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_xor_si256(a, b); });
+  // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
+#if defined(CPU_CAPABILITY_AVX2)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
+#elif defined(CPU_CAPABILITY_AVX512)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
+#endif
 }
 
 #else
 
 template<class T, typename Op>
 static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
-  static constexpr uint32_t element_no = 32 / sizeof(intmax_t);
-  __at_align32__ intmax_t buffer[element_no];
+  static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
+  __at_align__ intmax_t buffer[element_no];
   const intmax_t *a_ptr = reinterpret_cast<const intmax_t*>((const T*) a);
   const intmax_t *b_ptr = reinterpret_cast<const intmax_t*>((const T*) b);
   for (uint32_t i = 0U; i < element_no; ++ i) {
@@ -724,12 +760,12 @@ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
   return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
 }
 
-#endif
+#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
 
 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
 inline Vectorized<T> operator~(const Vectorized<T>& a) {
   Vectorized<T> ones;  // All bits are 1
-  memset((T*) ones, 0xFF, 32);
+  memset((T*) ones, 0xFF, VECTOR_WIDTH);
   return a ^ ones;
 }
 
@@ -802,7 +838,9 @@ inline mask_gather(const Vectorized<T>& src, T const* base_addr,
 }
 
 // Cast a given vector to another type without changing the bits representation.
-// So a Vec<double> of 256 bits containing all ones can be cast to a
+// So a Vectorized<double> of 512 bits containing all ones can be cast to a
+// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
+// A Vec<double> of 256 bits containing all ones can be cast to a
 // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
 namespace {
 // There is a struct here because we don't have static_if and I can't
@@ -840,10 +878,16 @@ inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectoriz
   return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
 }
 
-// E.g., inputs: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
-//               b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
-//      returns: Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
-//               Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
+// Example inputs for AVX512:
+// a   Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
+// b   Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
+// returns:
+//     Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
+//     Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
+// Example inputs for AVX2: a   Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
+//                          b   Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
+// returns:                     Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
+//                              Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
 template <typename T>
 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
 deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
@@ -866,8 +910,14 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
 }
 
 // inverse operation of deinterleave2
-// E.g., inputs: a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
-//               b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
+// Example inputs for AVX512:
+// a   Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
+// b   Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
+// returns, for AVX512:
+//     Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
+//     Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
+// Example inputs for AVX2 : a   Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
+//                           b   Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
 // returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
 //          Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
 template <typename T>
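The comments above describe deinterleave2 and interleave2 purely by example. The following scalar reference is a hedged sketch, not the ATen implementation: it computes the same index mapping on plain std::array values, with N standing in for Vectorized<T>::size().

// Scalar reference sketch of the deinterleave2 / interleave2 index mapping
// described in the comments above. Illustrative only, not ATen code; N must be even.
#include <array>
#include <cstddef>
#include <utility>

template <typename T, std::size_t N>
std::pair<std::array<T, N>, std::array<T, N>>
deinterleave2_ref(const std::array<T, N>& a, const std::array<T, N>& b) {
  static_assert(N % 2 == 0, "size must be even");
  std::array<T, N> firsts{}, seconds{};
  // a = {a0, b0, a1, b1, ...}, b = {aK, bK, ...}: gather the even slots
  // (the a's) into one output and the odd slots (the b's) into the other.
  for (std::size_t i = 0; i < N / 2; i++) {
    firsts[i]          = a[2 * i];
    firsts[N / 2 + i]  = b[2 * i];
    seconds[i]         = a[2 * i + 1];
    seconds[N / 2 + i] = b[2 * i + 1];
  }
  return {firsts, seconds};
}

template <typename T, std::size_t N>
std::pair<std::array<T, N>, std::array<T, N>>
interleave2_ref(const std::array<T, N>& a, const std::array<T, N>& b) {
  static_assert(N % 2 == 0, "size must be even");
  std::array<T, N> lo{}, hi{};
  // Inverse mapping: a = {a0..aN-1} and b = {b0..bN-1} become
  // {a0, b0, a1, b1, ...}, split across the two outputs.
  for (std::size_t i = 0; i < N / 2; i++) {
    lo[2 * i]     = a[i];
    lo[2 * i + 1] = b[i];
    hi[2 * i]     = a[N / 2 + i];
    hi[2 * i + 1] = b[N / 2 + i];
  }
  return {lo, hi};
}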
@@ -35,21 +35,8 @@
 #include <mkl.h>
 #endif
 
-// [Note SSE-AVX transitions]
-// There is a bug in Glibc2.23
-// https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall
-// when using AVX/AVX2 code resolves this.
-#if defined(CPU_CAPABILITY_AVX) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23
-#define DL_RUNTIME_BUG(op, type_) \
-  using value_t = typename c10::scalar_value_type<type_>::type;\
-  volatile value_t x = (value_t)(1); \
-  x = std::op(x); \
-  _mm256_zeroall();
-#define DL_RUNTIME_BUG_BFLOAT16() _mm256_zeroall();
-#else
 #define DL_RUNTIME_BUG(op, type_)
 #define DL_RUNTIME_BUG_BFLOAT16()
-#endif
 
 namespace at {
 namespace vml {
@@ -117,36 +104,36 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
   }); \
 }
 
-IMPLEMENT_VML_BUG(abs)
-IMPLEMENT_VML_BUG(acos)
-IMPLEMENT_VML_BUG(asin)
-IMPLEMENT_VML_BUG(atan)
-IMPLEMENT_VML_BUG(ceil)
-IMPLEMENT_VML_BUG(cos)
+IMPLEMENT_VML(abs)
+IMPLEMENT_VML(acos)
+IMPLEMENT_VML(asin)
+IMPLEMENT_VML(atan)
+IMPLEMENT_VML(ceil)
+IMPLEMENT_VML(cos)
 // IMPLEMENT_VML_BUG(cosh)
-IMPLEMENT_VML_BUG(erf)
-IMPLEMENT_VML_BUG(erfc)
+IMPLEMENT_VML(erf)
+IMPLEMENT_VML(erfc)
 IMPLEMENT_VML(erfinv)
-IMPLEMENT_VML_BUG(exp)
-IMPLEMENT_VML_BUG(expm1)
-IMPLEMENT_VML_BUG(floor)
+IMPLEMENT_VML(exp)
+IMPLEMENT_VML(expm1)
+IMPLEMENT_VML(floor)
 IMPLEMENT_VML(i0)
 IMPLEMENT_VML(i0e)
 IMPLEMENT_VML(reciprocal)
-IMPLEMENT_VML_BUG(log)
-IMPLEMENT_VML_BUG(log10)
-IMPLEMENT_VML_BUG(log1p)
-IMPLEMENT_VML_BUG(log2)
+IMPLEMENT_VML(log)
+IMPLEMENT_VML(log10)
+IMPLEMENT_VML(log1p)
+IMPLEMENT_VML(log2)
 IMPLEMENT_VML(neg)
-IMPLEMENT_VML_BUG(sin)
+IMPLEMENT_VML(sin)
 // IMPLEMENT_VML_BUG(sinh)
-IMPLEMENT_VML_BUG(sqrt)
-IMPLEMENT_VML_BUG(round)
+IMPLEMENT_VML(sqrt)
+IMPLEMENT_VML(round)
 IMPLEMENT_VML(rsqrt)
-IMPLEMENT_VML_BUG(tan)
-IMPLEMENT_VML_BUG(tanh)
-IMPLEMENT_VML_BUG(trunc)
-IMPLEMENT_VML_BUG(lgamma)
+IMPLEMENT_VML(tan)
+IMPLEMENT_VML(tanh)
+IMPLEMENT_VML(trunc)
+IMPLEMENT_VML(lgamma)
 
 
 #if AT_MKL_ENABLED() && !defined(__APPLE__)
@@ -952,91 +952,109 @@ void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(cholesky_stub, &cholesky_kernel);
+REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
+REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl);
+REGISTER_AVX512_DISPATCH(eig_stub, &eig_kernel_impl);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
+REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
+REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(geqrf_stub, &geqrf_kernel);
+REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
+REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(ormqr_stub, &ormqr_kernel);
+REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(lstsq_stub, &lstsq_kernel);
+REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel);
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
+REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
-REGISTER_AVX_DISPATCH(lu_stub, &lu_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
 
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
-REGISTER_AVX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
 
 }} // namespace at::native
@@ -16,12 +16,12 @@ static CPUCapability compute_cpu_capability() {
       return CPUCapability::VSX;
     }
 #else
+    if (strcmp(envar, "avx512") == 0) {
+      return CPUCapability::AVX512;
+    }
     if (strcmp(envar, "avx2") == 0) {
       return CPUCapability::AVX2;
     }
-    if (strcmp(envar, "avx") == 0) {
-      return CPUCapability::AVX;
-    }
 #endif
     if (strcmp(envar, "default") == 0) {
       return CPUCapability::DEFAULT;
@@ -31,12 +31,13 @@ static CPUCapability compute_cpu_capability() {
 
 #if !defined(__powerpc__) && !defined(__s390x__)
   if (cpuinfo_initialize()) {
+    if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && \
+        cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) {
+      return CPUCapability::AVX512;
+    }
     if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
       return CPUCapability::AVX2;
     }
-    if (cpuinfo_has_x86_avx()) {
-      return CPUCapability::AVX;
-    }
   }
 #endif
 #ifdef HAVE_VSX_CPU_DEFINITION
@@ -54,8 +55,8 @@ CPUCapability get_cpu_capability() {
 void* DispatchStubImpl::get_call_ptr(
   DeviceType device_type
   , void *DEFAULT
-#ifdef HAVE_AVX_CPU_DEFINITION
-  , void *AVX
+#ifdef HAVE_AVX512_CPU_DEFINITION
+  , void *AVX512
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
   , void *AVX2
@@ -72,8 +73,8 @@ void* DispatchStubImpl::get_call_ptr(
   if (!fptr) {
     fptr = choose_cpu_impl(
       DEFAULT
-#ifdef HAVE_AVX_CPU_DEFINITION
-      , AVX
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , AVX512
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
       , AVX2
@@ -102,8 +103,8 @@ void* DispatchStubImpl::get_call_ptr(
 
 void* DispatchStubImpl::choose_cpu_impl(
   void *DEFAULT
-#ifdef HAVE_AVX_CPU_DEFINITION
-  , void *AVX
+#ifdef HAVE_AVX512_CPU_DEFINITION
+  , void *AVX512
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
   , void *AVX2
@@ -114,18 +115,26 @@ void* DispatchStubImpl::choose_cpu_impl(
 ) {
   auto capability = static_cast<int>(get_cpu_capability());
   (void)capability;
+#ifdef HAVE_AVX512_CPU_DEFINITION
+  if (capability >= static_cast<int>(CPUCapability::AVX512)) {
+    // Quantization kernels have also been disabled on Windows
+    // for AVX512 because some of their tests are flaky on Windows.
+    // Ideally, we should have AVX512 kernels for all kernels.
+    if (C10_UNLIKELY(!AVX512)) {
+      // dispatch to AVX2, since the AVX512 kernel is missing
+      TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
+      return AVX2;
+    } else {
+      return AVX512;
+    }
+  }
+#endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
   if (capability >= static_cast<int>(CPUCapability::AVX2)) {
     TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
     return AVX2;
   }
 #endif
-#ifdef HAVE_AVX_CPU_DEFINITION
-  if (capability >= static_cast<int>(CPUCapability::AVX)) {
-    TORCH_INTERNAL_ASSERT(AVX, "DispatchStub: missing AVX kernel");
-    return AVX;
-  }
-#endif
 #ifdef HAVE_VSX_CPU_DEFINITION
   if (capability >= static_cast<int>(CPUCapability::VSX)) {
     TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel");
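The new branch in choose_cpu_impl above prefers the AVX512 slot but falls back to the AVX2 slot when a kernel deliberately has no AVX512 variant. The following is a condensed, hedged sketch of that selection order, using a stand-in kernel table rather than the real DispatchStubImpl members.

// Hedged sketch of the selection order choose_cpu_impl implements above.
// Not the real ATen code: the kernel table and capability enum are stand-ins.
#include <cassert>

enum class Capability { DEFAULT = 0, AVX2 = 1, AVX512 = 2 };

using KernelFn = void (*)();

// One slot per capability; a null AVX512 slot means this kernel opted out of
// AVX512 (for example via REGISTER_NO_AVX512_DISPATCH), so AVX2 is used instead.
struct KernelTable {
  KernelFn DEFAULT = nullptr;
  KernelFn AVX2 = nullptr;
  KernelFn AVX512 = nullptr;
};

KernelFn choose_cpu_impl_sketch(const KernelTable& t, Capability cap) {
  if (cap >= Capability::AVX512) {
    if (t.AVX512 == nullptr) {
      assert(t.AVX2 && "missing AVX2 kernel");
      return t.AVX2;            // AVX512 machine, but this kernel has no AVX512 variant
    }
    return t.AVX512;
  }
  if (cap >= Capability::AVX2 && t.AVX2) {
    return t.AVX2;
  }
  return t.DEFAULT;             // scalar fallback always exists
}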
@@ -9,8 +9,8 @@
 
 // Implements instruction set specific function dispatch.
 //
-// Kernels that may make use of specialized instruction sets (e.g. AVX) are
-// compiled multiple times with different compiler flags (e.g. -mavx). A
+// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
+// compiled multiple times with different compiler flags (e.g. -mavx2). A
 // DispatchStub contains a table of function pointers for a kernel. At runtime,
 // the fastest available kernel is chosen based on the features reported by
 // cpuinfo.
@@ -50,8 +50,8 @@ enum class CPUCapability {
 #ifdef HAVE_VSX_CPU_DEFINITION
   VSX = 1,
 #else
-  AVX = 1,
-  AVX2 = 2,
+  AVX2 = 1,
+  AVX512 = 2,
 #endif
   NUM_OPTIONS
 };
@@ -71,8 +71,8 @@ struct TORCH_API DispatchStubImpl {
   void* get_call_ptr(
     DeviceType device_type
     , void *DEFAULT
-#ifdef HAVE_AVX_CPU_DEFINITION
-    , void *AVX
+#ifdef HAVE_AVX512_CPU_DEFINITION
+    , void *AVX512
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
     , void *AVX2
@@ -89,8 +89,8 @@ struct TORCH_API DispatchStubImpl {
   */
   void* choose_cpu_impl(
     void *DEFAULT
-#ifdef HAVE_AVX_CPU_DEFINITION
-    , void *AVX
+#ifdef HAVE_AVX512_CPU_DEFINITION
+    , void *AVX512
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
     , void *AVX2
@@ -126,8 +126,8 @@ private:
     return reinterpret_cast<FnPtr>(
       impl.get_call_ptr(device_type
       , reinterpret_cast<void*>(DEFAULT)
-#ifdef HAVE_AVX_CPU_DEFINITION
-      , reinterpret_cast<void*>(AVX)
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , reinterpret_cast<void*>(AVX512)
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
       , reinterpret_cast<void*>(AVX2)
@@ -155,8 +155,8 @@ public:
   }
 
   static FnPtr DEFAULT;
-#ifdef HAVE_AVX_CPU_DEFINITION
-  static FnPtr AVX;
+#ifdef HAVE_AVX512_CPU_DEFINITION
+  static FnPtr AVX512;
 #endif
 #ifdef HAVE_AVX2_CPU_DEFINITION
   static FnPtr AVX2;
@@ -203,10 +203,10 @@ struct RegisterHIPDispatch {
 #define REGISTER_ARCH_DISPATCH(name, arch, fn) \
   template <> decltype(fn) DispatchStub<decltype(fn), struct name>::arch = fn;
 
-#ifdef HAVE_AVX_CPU_DEFINITION
-#define REGISTER_AVX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX, fn)
+#ifdef HAVE_AVX512_CPU_DEFINITION
+#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
 #else
-#define REGISTER_AVX_DISPATCH(name, fn)
+#define REGISTER_AVX512_DISPATCH(name, fn)
 #endif
 
 #ifdef HAVE_AVX2_CPU_DEFINITION
@@ -223,8 +223,8 @@ struct RegisterHIPDispatch {
 
 #define REGISTER_NO_CPU_DISPATCH(name, fn_type) \
   REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast<fn_type>(nullptr)) \
-  REGISTER_AVX_DISPATCH(name, static_cast<fn_type>(nullptr)) \
+  REGISTER_AVX512_DISPATCH(name, static_cast<fn_type>(nullptr)) \
   REGISTER_AVX2_DISPATCH(name, static_cast<fn_type>(nullptr)) \
   REGISTER_VSX_DISPATCH(name, static_cast<fn_type>(nullptr))
 
 #define REGISTER_CUDA_DISPATCH(name, fn) \
@@ -244,6 +244,8 @@ struct RegisterHIPDispatch {
 // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
 #elif defined(CPU_CAPABILITY)
 #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
+#define REGISTER_NO_AVX512_DISPATCH(name, fn_type) \
+  REGISTER_AVX512_DISPATCH(name, static_cast<fn_type>(nullptr))
 #endif
 
 
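REGISTER_NO_AVX512_DISPATCH, added in the last hunk above, is what lets an individual kernel opt out of AVX512 while still being built under CPU_CAPABILITY_AVX512: it registers a null AVX512 slot, and dispatch then falls back to AVX2. A hedged usage sketch follows; the stub and function names are hypothetical, and DECLARE_DISPATCH / DEFINE_DISPATCH are assumed to follow their usual ATen form rather than being quoted from this PR.

// Hypothetical kernel, shown only to illustrate the macro usage; my_stub,
// my_fn_t and my_kernel_impl are stand-in names, not part of this PR.
// In a shared header:
//   using my_fn_t = void (*)(at::TensorIterator&);
//   DECLARE_DISPATCH(my_fn_t, my_stub);
// In the shared .cpp:
//   DEFINE_DISPATCH(my_stub);
// In the per-capability kernel .cpp, compiled once per CPU_CAPABILITY:
#ifndef CPU_CAPABILITY_AVX512
REGISTER_DISPATCH(my_stub, &my_kernel_impl);      // DEFAULT / AVX2 / VSX builds
#else
// Skip AVX512 for this kernel; at runtime the AVX2 slot is used instead.
REGISTER_NO_AVX512_DISPATCH(my_stub, my_fn_t);
#endif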
@@ -275,10 +275,10 @@ REGISTER_ARCH_DISPATCH(
     DEFAULT,
     &_segment_reduce_cpu_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
 
 // Currently some computation is being duplicated across forward and backward.
@@ -319,7 +319,7 @@ REGISTER_ARCH_DISPATCH(
     DEFAULT,
     &_segment_reduce_cpu_backward_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_AVX_DISPATCH(
+REGISTER_AVX512_DISPATCH(
     _segment_reduce_backward_stub,
     &_segment_reduce_cpu_backward_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -4,7 +4,7 @@ The most important things to know:
 compiled multiple times for different instruction sets.** Yes,
 this folder is named `cpu`, but that doesn't mean put any old
 CPU kernel it. Only put CPU kernels which need to be compiled
-multiple times to take advantage of AVX/SSE instructions, but
+multiple times to take advantage of AVX512/AVX2/SSE instructions, but
 only on processors that support them.

 **Ensure that all implementations in this folder are put in an

@@ -52,14 +52,14 @@ All of the `*.cpp` files in this folder will be compiled under all compiler
 flags specified by `CPU_CAPABILITY_FLAGS` in `aten/src/ATen/CMakeLists.txt`.

 The purpose of this is to allow the compilation with various compiler
-flags to enable features such as AVX instructions, while using runtime
-dispatch, which makes sure only valid instructions will be used on any
+flags to enable features such as AVX2 or AVX512 instructions, while using
+runtime dispatch, which makes sure only valid instructions will be used on any
 given platform.

-Vectorized.h provides a generic implementation of a vec type that allows
+vec.h provides a generic implementation of vec type that allows
 the programmer to write code packing various primitives (such as floats)
-within 256bit registers. vec defines various operators such as + and *
-and provides functions to allow operations such as max, min, etc.
+within 256bit & 512bits registers. vec defines various operators such as
++ and * and provides functions to allow operations such as max, min, etc.

 As an example `ReduceOpsKernel.cpp` implements a generic `kernel_` that reduces
 an entire array using a given associative binary operation such as +.

@@ -74,5 +74,5 @@ generic code, which will be compiled under multipled compilation settings.
 `../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains
 a generic definition of `sumImplAll`. This function allows the user to reduce
 over a dimension or all dimensions. The appropiate capability is chosen at
-runtime using cpuinfo. If the current platform has AVX, `sumImpl` will be set
-to `sumImplAll<CPUCapability::AVX>`.
+runtime using cpuinfo. If the current platform has AVX2, `sumImpl` will be set
+to `sumImplAll<CPUCapability::AVX2>`.
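
The README text above describes the general mechanism: one kernel source is built several times with different compiler flags, and a function pointer is selected at runtime from the detected CPU capability. The snippet below is a minimal, self-contained sketch of that idea; the names (`detect_capability`, `sum_avx2`, `choose_sum`) are illustrative stand-ins, not the actual ATen dispatch macros.

```cpp
// Standalone sketch of capability-based runtime dispatch (not ATen code).
#include <cstdio>

enum class CPUCapability { DEFAULT, AVX2, AVX512 };

// Stand-in for a cpuinfo query; a real build would ask the CPU via cpuid.
CPUCapability detect_capability() { return CPUCapability::AVX2; }

// One implementation per capability; in ATen these come from recompiling the
// same .cpp file under different CPU_CAPABILITY flags.
void sum_default(const float* x, int n, float* out) { float s = 0; for (int i = 0; i < n; ++i) s += x[i]; *out = s; }
void sum_avx2(const float* x, int n, float* out) { sum_default(x, n, out); /* would use 256-bit vectors */ }
void sum_avx512(const float* x, int n, float* out) { sum_default(x, n, out); /* would use 512-bit vectors */ }

using sum_fn = void (*)(const float*, int, float*);

sum_fn choose_sum() {
  switch (detect_capability()) {
    case CPUCapability::AVX512: return sum_avx512;
    case CPUCapability::AVX2:   return sum_avx2;
    default:                    return sum_default;
  }
}

int main() {
  float data[4] = {1, 2, 3, 4}, result = 0;
  choose_sum()(data, 4, &result);  // resolved once, then called like a normal function
  std::printf("%f\n", result);
}
```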
@@ -32,7 +32,8 @@ static inline bool is_outer_reduction(const int64_t* strides) {
 }

 template <typename func_t, typename vec_func_t>
-static inline void reduction128(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) {
+static inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
+    func_t op, vec_func_t vop, bool reduce) {
 VEC_LOOP_HEADER(func_t, data)
 const char* in1_ptr = data[1];
 Vec acc[4];

@@ -80,7 +81,7 @@ static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op,
 int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
 int64_t count = n / (4 * Vec::size());
 if (count > 0) {
-reduction128(data, count, vector_stride, op, vop, /*reduce=*/true);
+vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
 }
 char* ptrs[3] = { data[0], data[0], data[1] };
 int64_t strides[] = { 0, 0, sizeof(scalar_t) };

@@ -92,10 +93,14 @@ template <typename func_t, typename vec_func_t>
 static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
 VEC_LOOP_HEADER(func_t, data)

-// reduce down each column of 4 * Vec::size() elements (128 bytes)
+// reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
+#if defined(CPU_CAPABILITY_AVX512)
+int64_t outer_stride[2] = { 256, 256 };
+#else
 int64_t outer_stride[2] = { 128, 128 };
+#endif
 UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
-reduction128(data, size0, inner_stride, op, vop, /*reduce=*/false);
+vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
 });

 // reduce down the remaining columns
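
The 128 versus 256 byte strides in the hunk above follow directly from the vector width. Assuming four-byte floats, a 256-bit vector holds 8 elements and a 512-bit vector holds 16, so a column of `4 * Vec::size()` elements spans 128 or 256 bytes respectively. A compile-time check of that arithmetic (illustrative only, the constants below are not ATen symbols):

```cpp
#include <cstddef>

constexpr std::size_t kFloatsPerVecAvx2 = 256 / (8 * sizeof(float));   // 8 lanes
constexpr std::size_t kFloatsPerVecAvx512 = 512 / (8 * sizeof(float)); // 16 lanes

static_assert(4 * kFloatsPerVecAvx2 * sizeof(float) == 128, "AVX2 column is 128 bytes");
static_assert(4 * kFloatsPerVecAvx512 * sizeof(float) == 256, "AVX512 column is 256 bytes");

int main() {}
```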
@@ -219,9 +219,15 @@ inline void _vec_softmax(
 int64_t outer_stride = dim_size * dim_stride;
 int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
 int vectorized_step = Vec().size(); // Currently, we only support scalar_t with double or float32
-TORCH_CHECK(
+#ifdef CPU_CAPABILITY_AVX512
+TORCH_INTERNAL_ASSERT(
+    (vectorized_step == 16) || (vectorized_step == 8),
+    "vectorized_step must be 16 with dtype float or 8 with dtype double");
+#else
+TORCH_INTERNAL_ASSERT(
 (vectorized_step == 8) || (vectorized_step == 4),
 "vectorized_step must be 8 with dtype float or 4 with dtype double");
+#endif
 parallel_for(
 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
 int64_t idx = begin;
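
The asserted `vectorized_step` values in the softmax hunk above are just lane counts: register bits divided by element bits. A small compile-time sanity check of those numbers (illustrative, not the ATen `Vectorized<>` implementation):

```cpp
// lanes = register width in bits / element size in bits
constexpr int lanes(int register_bits, int element_bytes) {
  return register_bits / (8 * element_bytes);
}

static_assert(lanes(256, sizeof(float)) == 8 && lanes(256, sizeof(double)) == 4,
              "AVX2 lane counts: 8 floats or 4 doubles");
static_assert(lanes(512, sizeof(float)) == 16 && lanes(512, sizeof(double)) == 8,
              "AVX512 lane counts: 16 floats or 8 doubles");

int main() {}
```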
@@ -611,8 +611,15 @@ void nansum_kernel_impl(TensorIterator &iter) {

 } // namespace (anonymous)

-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);

+// nansum on Float16 has poor accuracy with AVX2, and more so with AVX512.
+// So until it's fixed, it won't be dispatched with AVX512. GH issue 59415.
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+#ifndef CPU_CAPABILITY_AVX512
 REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl);
+#else
+REGISTER_NO_AVX512_DISPATCH(nansum_stub, reduce_fn);
+#endif

 }} // namespace at::native
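
The comment above about `nansum` on Float16 is fundamentally about accumulation precision. The program below does not reproduce the exact Float16 issue tracked in GH 59415; it is only a generic illustration of how the same reduction drifts when the running sum is kept in a narrower type, which is the effect one precision level higher.

```cpp
#include <cstdio>

int main() {
  float narrow_acc = 0.0f;
  double wide_acc = 0.0;
  for (int i = 0; i < 10000000; ++i) {
    narrow_acc += 0.1f;  // rounding error compounds in the 32-bit accumulator
    wide_acc += 0.1f;
  }
  std::printf("float acc:  %.1f\n", narrow_acc);  // drifts noticeably away from 1000000.0
  std::printf("double acc: %.1f\n", wide_acc);    // stays ~1000000.0
}
```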
@@ -715,8 +715,15 @@ REGISTER_DISPATCH(exponential_stub, &CPU_CAPABILITY::exponential_kernel);
 REGISTER_DISPATCH(geometric_stub, &CPU_CAPABILITY::geometric_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(log_normal_stub, &CPU_CAPABILITY::log_normal_kernel);
+#ifdef CPU_CAPABILITY_AVX512
+// normal_stub isn't being dispatched to AVX512 because it exposes
+// flakiness in test_sgd of test/test_optim.py
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(normal_stub, void(*)(Tensor&, const double, const double, c10::optional<Generator>));
+#else
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(normal_stub, &CPU_CAPABILITY::normal_kernel);
+#endif
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(uniform_stub, &CPU_CAPABILITY::uniform_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -32,26 +32,17 @@

 #include <ATen/native/cpu/Intrinsics.h>

-/* yes I know, the top of this file is quite ugly */
+/* The original source of this file has been modified. */
+#if defined(CPU_CAPABILITY_AVX2)

 #if defined(__GNUC__)
 # define ALIGN32_BEG __attribute__((aligned(32)))
 #elif defined(_WIN32)
 # define ALIGN32_BEG __declspec(align(32))
 #endif

-/* __m128 is ugly to write */
-typedef __m256 v8sf; // vector of 8 float (avx)
-typedef __m256i v8si; // vector of 8 int (avx)
-typedef __m128i v4si; // vector of 8 int (avx)
+typedef __m256 v8sf; // vector of 8 float (avx2)
+typedef __m256i v8si; // vector of 8 int (avx2)

-#define _PI32AVX_CONST(Name, Val) \
-static const ALIGN32_BEG int _pi32avx_##Name[4] = { Val, Val, Val, Val }
-
-_PI32AVX_CONST(1, 1);
-_PI32AVX_CONST(inv1, ~1);
-_PI32AVX_CONST(2, 2);
-_PI32AVX_CONST(4, 4);

 /* declare some AVX constants -- why can't I figure a better way to do that? */
 #define _PS256_CONST(Name, Val) \

@@ -91,67 +82,6 @@ _PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
 _PS256_CONST(cephes_log_q1, -2.12194440e-4);
 _PS256_CONST(cephes_log_q2, 0.693359375);

-#ifndef CPU_CAPABILITY_AVX2
-
-typedef union imm_xmm_union {
-v8si imm;
-v4si xmm[2];
-} imm_xmm_union;
-
-#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) { \
-imm_xmm_union u __attribute__((aligned(32))); \
-u.imm = imm_; \
-xmm0_ = u.xmm[0]; \
-xmm1_ = u.xmm[1]; \
-}
-
-#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) { \
-imm_xmm_union u __attribute__((aligned(32))); \
-u.xmm[0]=xmm0_; u.xmm[1]=xmm1_; imm_ = u.imm; \
-}
-
-#define AVX2_BITOP_USING_SSE2(fn) \
-static inline v8si _mm256_##fn(v8si x, int a) \
-{ \
-/* use SSE2 instruction to perform the bitop AVX2 */ \
-v4si x1, x2; \
-v8si ret; \
-COPY_IMM_TO_XMM(x, x1, x2); \
-x1 = _mm_##fn(x1,a); \
-x2 = _mm_##fn(x2,a); \
-COPY_XMM_TO_IMM(x1, x2, ret); \
-return(ret); \
-}
-
-#warning "Using SSE2 to perform AVX2 bitshift ops"
-AVX2_BITOP_USING_SSE2(slli_epi32)
-AVX2_BITOP_USING_SSE2(srli_epi32)
-
-#define AVX2_INTOP_USING_SSE2(fn) \
-static inline v8si _mm256_##fn(v8si x, v8si y) \
-{ \
-/* use SSE2 instructions to perform the AVX2 integer operation */ \
-v4si x1, x2; \
-v4si y1, y2; \
-v8si ret; \
-COPY_IMM_TO_XMM(x, x1, x2); \
-COPY_IMM_TO_XMM(y, y1, y2); \
-x1 = _mm_##fn(x1,y1); \
-x2 = _mm_##fn(x2,y2); \
-COPY_XMM_TO_IMM(x1, x2, ret); \
-return(ret); \
-}
-
-#warning "Using SSE2 to perform AVX2 integer ops"
-AVX2_INTOP_USING_SSE2(and_si128)
-AVX2_INTOP_USING_SSE2(andnot_si128)
-AVX2_INTOP_USING_SSE2(cmpeq_epi32)
-AVX2_INTOP_USING_SSE2(sub_epi32)
-AVX2_INTOP_USING_SSE2(add_epi32)
-
-#endif /* CPU_CAPABILITY_AVX2 */

 /* natural logarithm computed for 8 simultaneous float
 return NaN for x <= 0

@@ -326,11 +256,6 @@ inline v8sf sin256_ps(v8sf x) { // any x
 v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
 v8si imm0, imm2;

-#ifndef CPU_CAPABILITY_AVX2
-v4si imm0_1, imm0_2;
-v4si imm2_1, imm2_2;
-#endif

 sign_bit = x;
 /* take the absolute value */
 x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);

@@ -346,7 +271,6 @@ inline v8sf sin256_ps(v8sf x) { // any x
 If we don't have AVX, let's perform them using SSE2 directives
 */

-#ifdef CPU_CAPABILITY_AVX2
 /* store the integer part of y in mm0 */
 imm2 = _mm256_cvttps_epi32(y);
 /* j=(j+1) & (~1) (see the cephes sources) */

@@ -366,35 +290,6 @@ inline v8sf sin256_ps(v8sf x) { // any x
 */
 imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
 imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
-#else
-/* we use SSE2 routines to perform the integer ops */
-COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
-
-imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
-imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
-
-imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
-imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
-
-COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
-y = _mm256_cvtepi32_ps(imm2);
-
-imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
-imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
-
-imm0_1 = _mm_slli_epi32(imm0_1, 29);
-imm0_2 = _mm_slli_epi32(imm0_2, 29);
-
-COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
-
-imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
-imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
-
-imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
-imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
-
-COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
-#endif

 v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
 v8sf poly_mask = _mm256_castsi256_ps(imm2);
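
The deleted SSE2 fallback above only re-implemented the integer part of the cephes-style range reduction that `sin256_ps` performs across eight lanes. The scalar sketch below shows what that reduction computes for one value, so the `imm0`/`imm2` mask manipulation in the SIMD code is easier to follow. It is illustrative only; the real code evaluates polynomials instead of calling `std::sin`/`std::cos`, and the split pi/4 constants are taken from the cephes convention used by this file.

```cpp
#include <cmath>
#include <cstdio>

float sin_range_reduce_demo(float x) {
  const float FOPI = 1.27323954473516f;   // 4 / pi
  float sign = x < 0 ? -1.0f : 1.0f;
  x = std::fabs(x);
  int j = static_cast<int>(x * FOPI);     // octant index
  j = (j + 1) & ~1;                       // j = (j+1) & (~1), as in the comment above
  float y = static_cast<float>(j);
  x = x - y * 0.78515625f                 // subtract j * pi/4 in three pieces
        - y * 2.4187564849853515625e-4f   // to keep extra precision
        - y * 3.77489497744594108e-8f;
  // The low bits of j select the polynomial and the sign, exactly what the
  // imm0 / imm2 vector masks encode.
  float poly = (j & 2) ? std::cos(x) : std::sin(x);  // stand-in for the polynomials
  if (j & 4) sign = -sign;
  return sign * poly;
}

int main() {
  std::printf("%f vs %f\n", sin_range_reduce_demo(3.0f), std::sin(3.0f));
}
```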
@@ -453,18 +348,12 @@ inline v8sf cos256_ps(v8sf x) { // any x
 v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
 v8si imm0, imm2;

-#ifndef CPU_CAPABILITY_AVX2
-v4si imm0_1, imm0_2;
-v4si imm2_1, imm2_2;
-#endif

 /* take the absolute value */
 x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);

 /* scale by 4/Pi */
 y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

-#ifdef CPU_CAPABILITY_AVX2
 /* store the integer part of y in mm0 */
 imm2 = _mm256_cvttps_epi32(y);
 /* j=(j+1) & (~1) (see the cephes sources) */

@@ -479,39 +368,6 @@ inline v8sf cos256_ps(v8sf x) { // any x
 /* get the polynom selection mask */
 imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
 imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
-#else
-
-/* we use SSE2 routines to perform the integer ops */
-COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
-
-imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
-imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
-
-imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
-imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
-
-COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
-y = _mm256_cvtepi32_ps(imm2);
-
-imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2);
-imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2);
-
-imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4);
-imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4);
-
-imm0_1 = _mm_slli_epi32(imm0_1, 29);
-imm0_2 = _mm_slli_epi32(imm0_2, 29);
-
-COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
-
-imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
-imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
-
-imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
-imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
-
-COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
-#endif

 v8sf sign_bit = _mm256_castsi256_ps(imm0);
 v8sf poly_mask = _mm256_castsi256_ps(imm2);

@@ -571,12 +427,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
 v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
 v8si imm0, imm2, imm4;

-#ifndef CPU_CAPABILITY_AVX2
-v4si imm0_1, imm0_2;
-v4si imm2_1, imm2_2;
-v4si imm4_1, imm4_2;
-#endif

 sign_bit_sin = x;
 /* take the absolute value */
 x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);

@@ -586,7 +436,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
 /* scale by 4/Pi */
 y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

-#ifdef CPU_CAPABILITY_AVX2
 /* store the integer part of y in imm2 */
 imm2 = _mm256_cvttps_epi32(y);

@@ -606,38 +455,7 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
 imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
 imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
 //v8sf poly_mask = _mm256_castsi256_ps(imm2);
-#else
-/* we use SSE2 routines to perform the integer ops */
-COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
-
-imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
-imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
-
-imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
-imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
-
-COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
-y = _mm256_cvtepi32_ps(imm2);
-
-imm4_1 = imm2_1;
-imm4_2 = imm2_2;
-
-imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
-imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
-
-imm0_1 = _mm_slli_epi32(imm0_1, 29);
-imm0_2 = _mm_slli_epi32(imm0_2, 29);
-
-COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
-
-imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
-imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
-
-imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
-imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
-
-COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
-#endif
 v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
 v8sf poly_mask = _mm256_castsi256_ps(imm2);

@@ -653,22 +471,9 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
 x = _mm256_add_ps(x, xmm2);
 x = _mm256_add_ps(x, xmm3);

-#ifdef CPU_CAPABILITY_AVX2
 imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
 imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
 imm4 = _mm256_slli_epi32(imm4, 29);
-#else
-imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2);
-imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2);
-
-imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4);
-imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4);
-
-imm4_1 = _mm_slli_epi32(imm4_1, 29);
-imm4_2 = _mm_slli_epi32(imm4_2, 29);
-
-COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
-#endif

 v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);

@@ -713,3 +518,5 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
 *s = _mm256_xor_ps(xmm1, sign_bit_sin);
 *c = _mm256_xor_ps(xmm2, sign_bit_cos);
 }
+
+#endif // CPU_CAPABILITY_AVX2
@@ -149,8 +149,8 @@ static void _fft_fill_with_conjugate_symmetry_cpu_(

 // Register this one implementation for all cpu types instead of compiling multiple times
 REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
-REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
 REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
+REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

 // _out variants can be shared between PocketFFT and MKL
 Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
@@ -6,6 +6,7 @@
 #include <ATen/native/UpSample.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/native/quantized/fake_quant_affine.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <c10/util/irange.h>

@@ -197,7 +198,28 @@ int64_t hsum(const uint8_t* A, int len) {
 for (const auto k : c10::irange(8)) {
 row_sum += temp[k];
 }
-#endif // CPU_CAPABILITY_AVX2
+#elif defined(CPU_CAPABILITY_AVX512)
+__m512i sum_v = _mm512_setzero_si512();
+__m512i one_epi16_v = _mm512_set1_epi16(1);
+__m512i one_epi8_v = _mm512_set1_epi8(1);
+// vectorized
+for (; i < len / 64 * 64; i += 64) {
+__m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
+sum_v = _mm512_add_epi32(
+    sum_v,
+    _mm512_madd_epi16(
+        // first argument is unsigned, second is signed
+        _mm512_maddubs_epi16(src_v, one_epi8_v),
+        one_epi16_v)
+);
+}
+
+alignas(64) int32_t temp[16];
+_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
+for (const auto k : c10::irange(16)) {
+row_sum += temp[k];
+}
+#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512

 // scalar
 for (; i < len; ++i) {
@@ -233,7 +255,28 @@ int64_t hsum(const int8_t* A, int len) {
 for (const auto k : c10::irange(8)) {
 row_sum += temp[k];
 }
-#endif // CPU_CAPABILITY_AVX2
+#elif defined(CPU_CAPABILITY_AVX512)
+__m512i sum_v = _mm512_setzero_si512();
+__m512i one_epi16_v = _mm512_set1_epi16(1);
+__m512i one_epi8_v = _mm512_set1_epi8(1);
+// vectorized
+for (; i < len / 64 * 64; i += 64) {
+__m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
+sum_v = _mm512_add_epi32(
+    sum_v,
+    _mm512_madd_epi16(
+        // first argument is unsigned, second is signed
+        _mm512_maddubs_epi16(one_epi8_v, src_v),
+        one_epi16_v)
+);
+}
+
+alignas(64) int32_t temp[16];
+_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
+for (const auto k : c10::irange(16)) {
+row_sum += temp[k];
+}
+#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512

 // scalar
 for (; i < len; ++i) {
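
Both `hsum` hunks above rely on the same widening trick: `maddubs` multiplies unsigned bytes by a signed operand of all ones and adds adjacent pairs into 16-bit lanes, then `madd` multiplies those by one and adds adjacent pairs into 32-bit lanes. That is also why the operand order flips in the signed `int8_t` variant, since the first argument of `maddubs` is the unsigned one. The scalar model below mirrors that arithmetic for a single group of four bytes; it is illustrative and does not model the full register layout.

```cpp
#include <cstdint>
#include <cstdio>

int32_t sum4_bytes_via_madd(uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3) {
  // maddubs(src, 1): pairwise u8*1 + u8*1 -> i16 (no saturation possible here,
  // since each pair sum is at most 510)
  int16_t pair0 = static_cast<int16_t>(b0) * 1 + static_cast<int16_t>(b1) * 1;
  int16_t pair1 = static_cast<int16_t>(b2) * 1 + static_cast<int16_t>(b3) * 1;
  // madd(pairs, 1): pairwise i16*1 + i16*1 -> i32
  return static_cast<int32_t>(pair0) + static_cast<int32_t>(pair1);
}

int main() {
  // 250 + 7 + 13 + 255 = 525, the same as a plain byte sum
  std::printf("%d\n", sum4_bytes_via_madd(250, 7, 13, 255));
}
```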
@@ -255,7 +298,7 @@ int64_t hsum(const int32_t* A, int len) {
 __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
 // widen
 __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32);
-__m128i src_hi_epi32 = _mm256_extractf128_si256(src_epi32, 1);
+__m128i src_hi_epi32 = _mm256_extracti128_si256(src_epi32, 1);
 __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32);
 __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32);
 // add

@@ -268,7 +311,27 @@ int64_t hsum(const int32_t* A, int len) {
 for (const auto k : c10::irange(4)) {
 row_sum += temp[k];
 }
-#endif // CPU_CAPABILITY_AVX2
+#elif defined(CPU_CAPABILITY_AVX512)
+__m512i sum_epi64 = _mm512_setzero_si512();
+// vectorized
+for (; i < len / 16 * 16; i += 16) {
+__m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
+// widen
+__m256i src_lo_epi32 = _mm512_castsi512_si256(src_epi32);
+__m256i src_hi_epi32 = _mm512_extracti32x8_epi32(src_epi32, 1);
+__m512i src_lo_epi64 = _mm512_cvtepi32_epi64(src_lo_epi32);
+__m512i src_hi_epi64 = _mm512_cvtepi32_epi64(src_hi_epi32);
+// add
+sum_epi64 = _mm512_add_epi64(sum_epi64, src_lo_epi64);
+sum_epi64 = _mm512_add_epi64(sum_epi64, src_hi_epi64);
+}
+
+alignas(64) int64_t temp[8];
+_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_epi64);
+for (const auto k : c10::irange(8)) {
+row_sum += temp[k];
+}
+#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512

 // scalar
 for (; i < len; ++i) {
@@ -313,7 +376,36 @@ int64_t hsum_sq(const uint8_t* A, int len) {
 }
 sum_v_epu32 = _mm256_setzero_si256();
 }
-#endif // CPU_CAPABILITY_AVX2
+#elif defined(CPU_CAPABILITY_AVX512)
+__m512i sum_v_epu32 = _mm512_setzero_si512();
+alignas(64) int32_t temp[16];
+int overflow_threshold = 262144; // 2147483647(max of int32)/(512*512)*8 = 262144
+int loop = len / overflow_threshold + 1;
+for(int j=0; j<=loop; j++){
+for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
+// (i31, ..., i0)
+__m256i src_epu8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+__m512i src_epu16 = _mm512_cvtepu8_epi16(src_epu8);
+// (i31 ^ 2, ..., i0 ^ 2)
+__m512i sq_epu16 = _mm512_mullo_epi16(src_epu16, src_epu16);
+// (i15 ^ 2, ..., i0 ^ 2)
+__m256i sq_lo_epu16 = _mm512_castsi512_si256(sq_epu16);
+// (i31 ^ 2, ..., i16 ^ 2)
+__m256i sq_hi_epu16 = _mm512_extracti32x8_epi32(sq_epu16, 1);
+// widen to epu32
+__m512i sq_lo_epu32 = _mm512_cvtepu16_epi32(sq_lo_epu16);
+__m512i sq_hi_epu32 = _mm512_cvtepu16_epi32(sq_hi_epu16);
+// add to running sum
+sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_lo_epu32);
+sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_hi_epu32);
+}
+_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epu32);
+for (const auto k : c10::irange(16)) {
+row_sum += temp[k];
+}
+sum_v_epu32 = _mm512_setzero_si512();
+}
+#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512

 // scalar
 for (; i < len; ++i) {
@@ -361,7 +453,40 @@ int64_t hsum_sq(const int8_t* A, int len) {
 }
 sum_v_epi32 = _mm256_setzero_si256();
 }
-#endif // CPU_CAPABILITY_AVX2
+#elif defined(CPU_CAPABILITY_AVX512)
+// vectorized
+__m512i sum_v_epi32 = _mm512_setzero_si512();
+alignas(64) int32_t temp[16];
+
+int overflow_threshold = 1048576; //2147483647/(256*256)*8 = 1048576
+int loop = len / overflow_threshold + 1;
+
+for(int j=0; j<=loop; j++){
+for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
+// (i31, ..., i0)
+__m256i src_epi8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+__m512i src_epi16 = _mm512_cvtepi8_epi16(src_epi8);
+// (i31 ^ 2, ..., i0 ^ 2)
+__m512i sq_epi16 = _mm512_mullo_epi16(src_epi16, src_epi16);
+// (i15 ^ 2, ..., i0 ^ 2)
+__m256i sq_lo_epi16 = _mm512_castsi512_si256(sq_epi16);
+// (i31 ^ 2, ..., i16 ^ 2)
+__m256i sq_hi_epi16 = _mm512_extracti32x8_epi32(sq_epi16, 1);
+// widen to epi32
+__m512i sq_lo_epi32 = _mm512_cvtepi16_epi32(sq_lo_epi16);
+__m512i sq_hi_epi32 = _mm512_cvtepi16_epi32(sq_hi_epi16);
+// add to running sum
+sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_lo_epi32);
+sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_hi_epi32);
+}
+_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epi32);
+
+for (const auto k : c10::irange(16)) {
+row_sum += temp[k];
+}
+sum_v_epi32 = _mm512_setzero_si512();
+}
+#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512

 // scalar
 for (; i < len; ++i) {
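
The `overflow_threshold` loops in the two `hsum_sq` hunks above exist because the squares are accumulated in 32-bit lanes, which can only absorb a limited number of terms before wrapping. The sketch below works through the worst case for unsigned bytes; it is a rough illustration of the constraint and does not claim to reproduce the exact constant chosen in the kernel.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t max_i32 = INT32_MAX;
  constexpr int64_t max_sq_u8 = 255LL * 255LL;                 // 65025, largest square of a byte
  constexpr int64_t safe_terms_per_lane = max_i32 / max_sq_u8; // about 33025 squares per lane
  // Each loop iteration consumes 32 bytes and spreads their squares over the
  // 16 lanes of the 512-bit accumulator, so each lane takes 2 squares per
  // iteration. The periodic spill into the scalar row_sum therefore has to
  // happen well before safe_terms_per_lane / 2 iterations have passed.
  std::printf("%lld worst-case squares fit in one int32 lane\n",
              static_cast<long long>(safe_terms_per_lane));
}
```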
@@ -391,7 +516,21 @@ float hsum_sq(const int32_t* A, int len) {
 for (const auto k : c10::irange(8)) {
 row_sum += static_cast<float>(temp[k]);
 }
-#endif // CPU_CAPABILITY_AVX2
+#elif defined(CPU_CAPABILITY_AVX512)
+__m512 sum_ps = _mm512_setzero_ps();
+// vectorized
+for (; i < len / 16 * 16; i += 16) {
+__m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
+__m512 src_ps = _mm512_cvtepi32_ps(src_epi32);
+sum_ps = _mm512_add_ps(sum_ps, _mm512_mul_ps(src_ps, src_ps));
+}
+
+alignas(64) float temp[16];
+_mm512_store_ps(temp, sum_ps);
+for (const auto k : c10::irange(16)) {
+row_sum += static_cast<float>(temp[k]);
+}
+#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512

 // scalar
 for (; i < len; ++i) {
@@ -1239,7 +1378,7 @@ void qmaxpool_2d_nhwc_kernel(
 }

 template <typename T>
-void do_avg_pool_nhwc_on_AVX2(
+void do_avg_pool_nhwc_on_AVX_n(
 const typename T::underlying* i_p,
 typename T::underlying* o_p,
 int& c_start,

@@ -1256,17 +1395,25 @@ void do_avg_pool_nhwc_on_AVX2(
 int hsize,
 int wsize,
 int csize) {
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
 // buffer for channel accumulator, used to interchange channel-loop
 // to inner-most, so that memory access of the input tensor data is
 // continuous.
+#ifdef CPU_CAPABILITY_AVX2
 constexpr int cb_size = 16;
+#else
+constexpr int cb_size = 8;
+#endif
 constexpr int vec_width = Vectorized<T>::size() / 4;
 constexpr int cb_step = cb_size * vec_width;
 Vectorized<int32_t> acc_buffer[cb_size];
 Vectorized<float> acc_buffer_fp[cb_size];

+#ifdef CPU_CAPABILITY_AVX2
 if (vec_width == 8) {
+#else
+if (vec_width == 16) {
+#endif
 for (int c = c_start; c < csize; c += cb_step) {
 int cend = std::min(cb_size, (csize - c) / vec_width);
 // initialize loop
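
The halving of `cb_size` under AVX512 in the hunk above keeps the channel-blocking arithmetic unchanged: the hunk itself shows that `vec_width` doubles from 8 to 16, so dropping `cb_size` from 16 to 8 leaves `cb_step` and the accumulator footprint the same. A compile-time restatement of that arithmetic (the constants are local illustrations, not ATen symbols):

```cpp
constexpr int avx2_vec_width = 8;     // Vectorized<T>::size() / 4 on AVX2, per the check above
constexpr int avx512_vec_width = 16;  // Vectorized<T>::size() / 4 on AVX512

static_assert(16 * avx2_vec_width == 128,  "AVX2: cb_size 16 gives 128 channels per block");
static_assert(8 * avx512_vec_width == 128, "AVX512: cb_size 8 gives the same 128 channels");
static_assert(16 * 32 == 8 * 64, "accumulator buffer bytes also stay equal (cb_size * vector bytes)");

int main() {}
```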
@@ -1292,14 +1439,23 @@ void do_avg_pool_nhwc_on_AVX2(
 // convert int32 accumulative to fp32
 vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width);

-// first quantize using AVX using 32 lanes, then 8, finally falls
+// first quantize using AVX2 or AVX512 using 32 lanes, then 8, finally falls
 // back to single
+#ifdef CPU_CAPABILITY_AVX2
 QuantizeAvx2<T>(
 (float*)acc_buffer_fp,
 o_p + c,
 cend * vec_width,
 multiplier,
 output_zero_point);
+#else
+QuantizeAvx512<T>(
+    (float*)acc_buffer_fp,
+    o_p + c,
+    cend * vec_width,
+    multiplier,
+    output_zero_point);
+#endif
 }
 c_start = csize / vec_width * vec_width;
 }

@@ -1307,7 +1463,7 @@ void do_avg_pool_nhwc_on_AVX2(
 }

 template <typename T>
-void do_avg_pool_on_AVX2(
+void do_avg_pool_on_AVX_n(
 typename T::underlying* i_p,
 typename T::underlying* o_p,
 int64_t& c,
@@ -1326,9 +1482,13 @@ void do_avg_pool_on_AVX2(
 int64_t stride_D,
 int64_t stride_H,
 int64_t stride_W) {
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
-constexpr auto vec_width = Vectorized<T>::size() / 4;
+#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
+constexpr int vec_width = Vectorized<T>::size() / 4;
+#ifdef CPU_CAPABILITY_AVX2
 if (vec_width == 8) {
+#else
+if (vec_width == 16) {
+#endif
 for (; c + vec_width <= channel_size; c += vec_width) {
 int64_t tcntr = 0;

@@ -1416,10 +1576,10 @@ void _qadaptive_avg_pool_kernel(
 istartH * istrideH +
 istartW * istrideW;

-// Note: If AVX is not available, `do_avg_pool_on_AVX2 is a noop.
+// Note: If AVX is not available, `do_avg_pool_on_AVX_n is a noop.
 // In that case, the following loop takes over
 // TODO: more vectorization with loop interleaving
-do_avg_pool_on_AVX2<scalar_t>(
+do_avg_pool_on_AVX_n<scalar_t>(
 internal_i_p,
 o_p,
 c,

@@ -1438,7 +1598,6 @@ void _qadaptive_avg_pool_kernel(
 istrideD,
 istrideH,
 istrideW);

 // 1) The following loop handles the remaining channels
 // 2) It also handles the Non-AVX2 path
 for (; c < sizeC; ++c) {

@@ -1610,7 +1769,7 @@ void _qavg_pool_nhwc_kernel(
 // For int8 quantization, we implicitly use int32 as accumulation
 // Or else, it will go to the slow path
 // TODO: support 16bit, 32bit, and etc.
-do_avg_pool_nhwc_on_AVX2<scalar_t>(
+do_avg_pool_nhwc_on_AVX_n<scalar_t>(
 i_p,
 o_p,
 c_start,
@@ -1744,7 +1903,7 @@ void qavg_pool3d_nhwc_kernel(
 }

 template <typename T>
-int64_t do_quantized_bilinear_on_AVX2(
+int64_t do_quantized_bilinear_on_AVX_n(
 const typename T::underlying*& pos1,
 typename T::underlying*& pos2,
 int64_t input_height,

@@ -1762,9 +1921,13 @@ int64_t do_quantized_bilinear_on_AVX2(
 const int64_t h1p,
 const int64_t w1p) {
 int64_t c = 0;
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
 constexpr auto vec_width = Vectorized<T>::size() / 4;
+#ifdef CPU_CAPABILITY_AVX2
 if (vec_width == 8) {
+#else
+if (vec_width == 16) {
+#endif
 for (; c + vec_width <= channels; c += vec_width) {
 Vectorized<float> pos1_fp_v[4];
 Vectorized<int32_t> pos1_int_v[4];

@@ -1861,7 +2024,7 @@ void qupsample_bilinear2d_nhwc_kernel(
 o_p + (h2 * output_width + w2) * channels;
 // We have to isolate this function out because the VS does not
 // expand the macro correctly.
-c = do_quantized_bilinear_on_AVX2<scalar_t>(
+c = do_quantized_bilinear_on_AVX_n<scalar_t>(
 pos1,
 pos2,
 input_height,
@@ -1989,7 +2152,7 @@ void q_batch_norm_kernel(
 reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
 scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());

-constexpr int kVLen = 8;
+constexpr int kVLen = Vectorized<float>::size();
 const int64_t outer_size = N * HxW;
 using Vec = Vectorized<scalar_t>;
 // Hoisted variables

@@ -2292,7 +2455,7 @@ void quantized_normalize_kernel(
 float y_scale = Y->q_scale();
 float y_inv_scale = 1.0f / y_scale;

-constexpr int kFloatVLen = 8;
+constexpr int kFloatVLen = fVec::size();
 int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
 int64_t kNumIntVecInLayer = N / kIntVLen;
 int64_t kNonVecRemInLayer = N % kIntVLen;
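
The two hunks above replace a hard-coded lane count of 8 with the vector type's own `size()`, so the same loop structure works for both 256-bit and 512-bit builds. The plain C++ stand-in below shows the shape of such length-parameterized code; it uses an ordinary template parameter instead of the ATen `Vectorized` class.

```cpp
#include <array>
#include <cstdio>

template <int kVLen>  // 8 for a 256-bit build, 16 for a 512-bit build
float sum_in_blocks(const float* x, int n) {
  float total = 0.0f;
  int i = 0;
  for (; i + kVLen <= n; i += kVLen) {      // "vectorized" body, kVLen lanes at a time
    std::array<float, kVLen> lanes{};
    for (int j = 0; j < kVLen; ++j) lanes[j] = x[i + j];
    for (float v : lanes) total += v;
  }
  for (; i < n; ++i) total += x[i];         // scalar tail
  return total;
}

int main() {
  float data[20];
  for (int i = 0; i < 20; ++i) data[i] = 1.0f;
  std::printf("%f %f\n", sum_in_blocks<8>(data, 20), sum_in_blocks<16>(data, 20));
}
```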
@@ -3095,6 +3258,114 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu(

 } // namespace

+// Some quantization tests are flaky on Windows with AVX512. If --continue-through-error
+// is used, only one fails. But if the failing test is skipped, another one fails.
+// If the second test is also skipped, a third one fails.
+// So, until Quantization support for Windows is fixed for AVX512,
+// AVX2 kernels would be used instead. Ref: GH 56992.
+#if defined(CPU_CAPABILITY_AVX512) && defined(_WIN32)
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub,
+                            dequantize_tensor_per_channel_affine_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
+                            dequantize_tensor_per_tensor_affine_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
+                            dequantize_tensor_per_channel_float_qparams_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_tensor_stub,
+                            fake_quant_learnable_grad_tensor_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub,
+                            fake_quant_per_channel_cachemask_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_stub,
+                            fake_quant_tensor_cachemask_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
+                            fake_quant_tensor_cachemask_tensor_qparams_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
+                            qadaptive_avg_pool2d_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
+                            qadaptive_avg_pool3d_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qadd_relu_stub, qbinary_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qadd_scalar_relu_stub, qadd_scalar_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qadd_scalar_stub, qadd_scalar_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qadd_stub, qbinary_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, qavg_pool2d_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, qavg_pool3d_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qbatch_norm_relu_stub, qbatch_norm_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qbatch_norm_stub, qbatch_norm_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qcat_nhwc_stub, qcat_nhwc_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qcat_relu_nhwc_stub, qcat_nhwc_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qclamp_stub, qclamp_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qclamp_min_stub, qclamp_minmax_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qclamp_max_stub, qclamp_minmax_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qelu_stub, qelu_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qhardsigmoid_stub, qhardsigmoid_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qhardswish_stub, qhardswish_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qmaxpool_2d_nhwc_stub, qmaxpool_2d_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub, qbinary_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qmul_stub, qbinary_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qrelu6_stub, qrelu_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub, qrelu_leaky_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qrelu_stub, qrelu_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub, qsigmoid_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qtanh_stub, qtanh_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qthreshold_stub, qthreshold_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qtopk_stub, qtopk_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_channel_stub,
+                            fake_quant_learnable_per_channel_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_stub,
+                            quantize_tensor_per_tensor_affine_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_affine_stub,
+                            quantize_tensor_per_channel_affine_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_float_qparams_stub,
+                            quantize_tensor_per_channel_float_qparams_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(quantized_normalize_stub, qnormalize_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(qupsample_bilinear2d_nhwc_stub, qupsample_bilinear2d_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub,
+                            quantize_tensor_per_tensor_affine_sub_byte_fn);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub,
+                            dequantize_tensor_per_tensor_affine_sub_byte_fn);
+#else
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
 &dequantize_tensor_per_channel_affine_cpu);

@@ -3174,7 +3445,8 @@ REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &fake_quantize_learnable_channel_grad_kernel_cpu);
+REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
+                  &fake_quantize_learnable_channel_grad_kernel_cpu);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(
 quantize_tensor_per_tensor_affine_stub,

@@ -3200,7 +3472,7 @@ REGISTER_DISPATCH(
 REGISTER_DISPATCH(
 dequantize_tensor_per_tensor_affine_sub_byte_stub,
 &dequantize_tensor_per_tensor_affine_sub_byte_cpu);
+#endif // CPU_CAPABILITY_AVX512 && _WIN32

 } // namespace native
 } // namespace at
@ -1071,13 +1071,17 @@ namespace {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
TEST(ComplexTests, TestComplexFloatImagRealConj) {
|
TEST(ComplexTests, TestComplexFloatImagRealConj) {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28 };
|
float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28,
|
||||||
|
9.5488e-28,10.5488e-28,11.5488e-28,12.5488e-28,13.5488e-28,14.5488e-28,15.5488e-28,16.5488e-28};
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0 };
|
float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0,aa[8],0,aa[10],0,aa[12],0,aa[14],0 };
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0 };
|
float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0,aa[9],0,aa[11],0,aa[13],0,aa[15],0 };
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28 };
|
float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,
|
||||||
|
5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28,
|
||||||
|
9.5488e-28,-10.5488e-28,11.5488e-28,-12.5488e-28,
|
||||||
|
13.5488e-28,-14.5488e-28,15.5488e-28,-16.5488e-28 };
|
||||||
auto a = vcomplex::loadu(aa);
|
auto a = vcomplex::loadu(aa);
|
||||||
auto actual1 = a.real();
|
auto actual1 = a.real();
|
||||||
auto actual3 = a.imag();
|
auto actual3 = a.imag();
|
||||||
|
|
@ -1304,6 +1308,7 @@ namespace {
|
||||||
},
|
},
|
||||||
test_case);
|
test_case);
|
||||||
}
|
}
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
TYPED_TEST(FunctionalTests, Map) {
|
TYPED_TEST(FunctionalTests, Map) {
|
||||||
using vec = TypeParam;
|
using vec = TypeParam;
|
||||||
using VT = ValueType<TypeParam>;
|
using VT = ValueType<TypeParam>;
|
||||||
|
|
@@ -1339,15 +1344,16 @@ namespace {
     at::vec::map3<VT>([](vec x1, vec x2, vec x3) { return x1 + x2 + x3; }, y, x1, x2, x3, N);
     for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i]; }
     cmp(y, ref_y);
-    // test map3: y = x1 + x2 + x3 + x4
+    // test map4: y = x1 + x2 + x3 + x4
     at::vec::map4<VT>([](vec x1, vec x2, vec x3, vec x4) { return x1 + x2 + x3 + x4; }, y, x1, x2, x3, x4, N);
     for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i] + x4[i]; }
     cmp(y, ref_y);
   }
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
   TYPED_TEST(FunctionalBF16Tests, Reduce) {
     using vec = TypeParam;
     // Can't use ValueType<TypeParam> here:
-    // Vectorized<BFloat16>::value_type returns uint16_t on AVX2
+    // Vectorized<BFloat16>::value_type returns uint16_t on AVX2/AVX512
     using VT = c10::BFloat16;
     using RT = float; // reference
     constexpr auto R = 2LL; // residual
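The comment about `value_type` above is worth spelling out: on AVX2/AVX512 the BFloat16 lanes are stored as raw `uint16_t` bit patterns (the upper 16 bits of a `float`), so arithmetic on `value_type` would be integer arithmetic on bit patterns, which is why the test computes its reference values in `float`. A small, standalone sketch of that bit-level relationship (plain C++ assuming IEEE-754 `float`; not ATen code):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Truncate a float to its bfloat16 bit pattern (top 16 bits of the float encoding).
std::uint16_t float_to_bf16_bits(float x) {
  std::uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return static_cast<std::uint16_t>(u >> 16);
}

// Widen a bfloat16 bit pattern back to float by zero-filling the low 16 bits.
float bf16_bits_to_float(std::uint16_t bits) {
  std::uint32_t u = static_cast<std::uint32_t>(bits) << 16;
  float x;
  std::memcpy(&x, &u, sizeof(x));
  return x;
}

int main() {
  // 1.5f is exactly representable in bfloat16, so the round trip is lossless.
  assert(bf16_bits_to_float(float_to_bf16_bits(1.5f)) == 1.5f);
  // Adding two raw uint16_t lanes would add bit patterns, not values; hence the
  // float reference type (RT) in the test above.
  return 0;
}
```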
@@ -1394,7 +1400,6 @@ namespace {
     auto y2 = at::vec::map_reduce_all<VT>([](auto x) { return x - x.exp(); }, sum, x_b1, len);
     ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed
         << "\nmap_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2);
-
   }
   // Map2ReduceAll
   for (int64_t len = 1; len <= N; len++) {
@@ -13,7 +13,13 @@
 #include <math.h>
 #include <float.h>
 #include <algorithm>

+#if defined(CPU_CAPABILITY_AVX512)
+#define CACHE_LINE 64
+#else
 #define CACHE_LINE 32
+#endif

 #if defined(__GNUC__)
 #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE)))
 #define not_inline __attribute__((noinline))
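The bump from 32 to 64 above matches the register width: with AVX512 a full vector is 64 bytes, so `CACHE_ALIGN` buffers also start on a vector-width boundary. A minimal, standalone sketch of how such a macro is typically used, assuming GCC/Clang and hard-coding the AVX512 value for illustration:

```cpp
#include <cstdint>

// Standalone sketch; the real macros come from the test header above.
// 64 is the value CACHE_LINE takes when CPU_CAPABILITY_AVX512 is defined:
// one cache line, and also one 512-bit (64-byte) vector.
#define CACHE_LINE 64
#define CACHE_ALIGN __attribute__((aligned(CACHE_LINE)))

int main() {
  // 16 floats = 64 bytes: exactly one AVX512 vector, starting on a 64-byte boundary.
  CACHE_ALIGN float buffer[16] = {0};
  // Because the address is a multiple of 64, aligned vector loads from `buffer` are valid.
  return (reinterpret_cast<std::uintptr_t>(buffer) % CACHE_LINE == 0) ? 0 : 1;
}
```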
@@ -26,7 +32,7 @@ CACHE_ALIGN #define
 #endif
 #if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER)
 #define TEST_AGAINST_DEFAULT 1
-#elif !defined(CPU_CAPABILITY_AVX) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX)
+#elif !defined(CPU_CAPABILITY_AVX512) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX)
 #define TEST_AGAINST_DEFAULT 1
 #else
 #undef TEST_AGAINST_DEFAULT
@@ -41,7 +47,8 @@ CACHE_ALIGN #define
       return __VA_ARGS__(std::forward<decltype(args)>(args)...); \
   }

-#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) && (defined(__GNUC__) || defined(__GNUG__))
+#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) || \
+    defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__))
 #undef CHECK_DEQUANT_WITH_LOW_PRECISION
 #define CHECK_WITH_FMA 1
 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2)
@@ -722,44 +722,43 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS})
  endif()

- # NOTE [ Linking AVX and non-AVX files ]
+ # NOTE [ Linking AVX-n and non-AVX-n files ]
  #
- # Regardless of the CPU capabilities, we build some files with AVX and AVX2
+ # Regardless of the CPU capabilities, we build some files with AVX2, and AVX512
  # instruction set. If the host CPU doesn't support those, we simply ignore their
  # functions at runtime during dispatch.
  #
  # We must make sure that those files are at the end of the input list when
  # linking the torch_cpu library. Otherwise, the following error scenario might
  # occur:
- # 1. A non-AVX and an AVX file both call a function defined with the `inline`
+ # 1. A non-AVX2 and an AVX2 file both call a function defined with the `inline`
  #    keyword
  # 2. The compiler decides not to inline this function
  # 3. Two different versions of the machine code are generated for this function:
- #    one without AVX instructions and one with AVX.
+ #    one without AVX2 instructions and one with AVX2.
- # 4. When linking, the AVX version is found earlier in the input object files,
+ # 4. When linking, the AVX2 version is found earlier in the input object files,
  #    so the linker makes the entire library use it, even in code not guarded by
  #    the dispatcher.
- # 5. A CPU without AVX support executes this function, encounters an AVX
+ # 5. A CPU without AVX2 support executes this function, encounters an AVX2
  #    instruction and crashes.
  #
  # Thus we organize the input files in the following order:
- # 1. All files with no AVX support
+ # 1. All files with no AVX-n support
- # 2. All files with AVX support (conveniently, they all have names ending with
+ # 2. All files with AVX2 support ('*AVX2.cpp')
- #    'AVX.cpp')
+ # 3. All files with AVX512 support ('*AVX512.cpp')
- # 3. All files with AVX2 support ('*AVX2.cpp')
  set(Caffe2_CPU_SRCS_NON_AVX)
- set(Caffe2_CPU_SRCS_AVX)
  set(Caffe2_CPU_SRCS_AVX2)
+ set(Caffe2_CPU_SRCS_AVX512)
  foreach(input_filename ${Caffe2_CPU_SRCS})
-   if(${input_filename} MATCHES "AVX\\.cpp")
+   if(${input_filename} MATCHES "AVX2\\.cpp")
-     list(APPEND Caffe2_CPU_SRCS_AVX ${input_filename})
-   elseif(${input_filename} MATCHES "AVX2\\.cpp")
      list(APPEND Caffe2_CPU_SRCS_AVX2 ${input_filename})
+   elseif(${input_filename} MATCHES "AVX512\\.cpp")
+     list(APPEND Caffe2_CPU_SRCS_AVX512 ${input_filename})
    else()
      list(APPEND Caffe2_CPU_SRCS_NON_AVX ${input_filename})
    endif()
  endforeach(input_filename)
- set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX} ${Caffe2_CPU_SRCS_AVX2})
+ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX2} ${Caffe2_CPU_SRCS_AVX512})

  # ==========================================================
  # END formerly-libtorch sources
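The NOTE above describes a link-order hazard that is easier to see with a concrete sketch. The setup below is hypothetical (the helper `sum4` and the caller names are invented, not ATen code), and the two translation units are collapsed into one file purely for brevity; it only illustrates why the objects compiled with AVX2/AVX512 flags must not be the first definitions the linker sees:

```cpp
// Minimal sketch of the scenario in the NOTE above. In a real build these would be
// two translation units, one compiled with -mavx2 and one without.

inline float sum4(const float* p) {
  // If the compiler decides not to inline this, each translation unit emits its own
  // out-of-line copy under the same (weak) symbol, and the linker keeps whichever
  // copy appears first in the list of input objects.
  float s = 0.f;
  for (int i = 0; i < 4; ++i) s += p[i];
  return s;
}

float call_from_default_object(const float* p) { return sum4(p); }  // imagine: built without -mavx2
float call_from_avx2_object(const float* p) { return sum4(p); }     // imagine: built with -mavx2

int main() {
  float data[4] = {1.f, 2.f, 3.f, 4.f};
  // If the AVX2 object were linked first, call_from_default_object() would silently
  // execute AVX2 code and crash on a CPU without AVX2, which is why the build lists
  // the non-AVX2 sources before the AVX2 and AVX512 ones.
  return static_cast<int>(call_from_default_object(data)) == 10 ? 0 : 1;
}
```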
@@ -63,14 +63,6 @@ if(INTERN_BUILD_ATEN_OPS)
     endif()
   endif(MSVC)

-  if(C_AVX_FOUND)
-    if(MSVC)
-      set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG}/arch:AVX ${CXX_AVX_FLAGS}")
-    else(MSVC)
-      set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG} ${CXX_AVX_FLAGS}")
-    endif(MSVC)
-  endif(C_AVX_FOUND)
-
   if(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
     set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/MapAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp")
   endif()
@@ -80,15 +72,16 @@ if(INTERN_BUILD_ATEN_OPS)
   list(APPEND CPU_CAPABILITY_NAMES "DEFAULT")
   list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}")

-  if(CXX_AVX_FOUND)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX_CPU_DEFINITION")
-    list(APPEND CPU_CAPABILITY_NAMES "AVX")
+  if(CXX_AVX512_FOUND)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_CPU_DEFINITION")
+    list(APPEND CPU_CAPABILITY_NAMES "AVX512")
     if(MSVC)
-      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX")
+      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
     else(MSVC)
-      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx")
+      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma")
     endif(MSVC)
-  endif(CXX_AVX_FOUND)
+  endif(CXX_AVX512_FOUND)

   if(CXX_AVX2_FOUND)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX2_CPU_DEFINITION")
@@ -103,11 +96,24 @@ if(INTERN_BUILD_ATEN_OPS)
     endif(COMPILER_SUPPORTS_NO_AVX256_SPLIT)

     list(APPEND CPU_CAPABILITY_NAMES "AVX2")
-    if(MSVC)
-      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2")
-    else(MSVC)
-      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}")
-    endif(MSVC)
+    if(DEFINED ENV{ATEN_AVX512_256})
+      if($ENV{ATEN_AVX512_256} MATCHES "TRUE")
+        if(CXX_AVX512_FOUND)
+          message("-- ATen AVX2 kernels will use 32 ymm registers")
+          if(MSVC)
+            list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
+          else(MSVC)
+            list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=native ${CPU_NO_AVX256_SPLIT_FLAGS}")
+          endif(MSVC)
+        endif(CXX_AVX512_FOUND)
+      endif()
+    else()
+      if(MSVC)
+        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2")
+      else(MSVC)
+        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}")
+      endif(MSVC)
+    endif()
   endif(CXX_AVX2_FOUND)

   if(CXX_VSX_FOUND)
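For intuition on the `ATEN_AVX512_256` branch above: the kernels keep using 256-bit AVX2 intrinsics, but compiling them with AVX512VL available (via the `-march=native` / `/arch:AVX512` flags chosen when `CXX_AVX512_FOUND` is set) lets the register allocator use ymm16 through ymm31 in addition to the 16 ymm registers plain AVX2 offers. A hypothetical sketch of the kind of code that can benefit; it is not an ATen kernel, just an illustration that needs at least `-mavx2` to compile:

```cpp
#include <immintrin.h>

// Sums a float buffer using 8 independent 256-bit accumulators.
// Built with plain -mavx2 the compiler only has ymm0-ymm15 to work with; built with
// AVX512VL also enabled, it may additionally allocate ymm16-ymm31, which can reduce
// register spills in kernels that keep many vectors live at once.
float sum_many_accumulators(const float* data, long n) {
  __m256 acc[8];
  for (int k = 0; k < 8; ++k) acc[k] = _mm256_setzero_ps();
  for (long i = 0; i + 64 <= n; i += 64) {
    for (int k = 0; k < 8; ++k) {
      acc[k] = _mm256_add_ps(acc[k], _mm256_loadu_ps(data + i + 8 * k));
    }
  }
  __m256 total = _mm256_setzero_ps();
  for (int k = 0; k < 8; ++k) total = _mm256_add_ps(total, acc[k]);
  alignas(32) float lanes[8];
  _mm256_store_ps(lanes, total);
  float s = 0.f;
  for (int k = 0; k < 8; ++k) s += lanes[k];
  return s;
}

int main() {
  float data[128];
  for (int i = 0; i < 128; ++i) data[i] = 1.0f;
  return sum_many_accumulators(data, 128) == 128.0f ? 0 : 1;
}
```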
@@ -12,6 +12,25 @@ SET(AVX_CODE "
   }
 ")

+SET(AVX512_CODE "
+  #include <immintrin.h>
+
+  int main()
+  {
+    __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0);
+    __m512i b = a;
+    __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
+    return 0;
+  }
+")
+
 SET(AVX2_CODE "
   #include <immintrin.h>

@@ -56,6 +75,8 @@ ENDMACRO()

 CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX")
 CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
+CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")

 CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX")
 CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
+CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
setup.py
@@ -99,6 +99,12 @@
 # BUILD_BINARY
 #   enables the additional binaries/ build
 #
+# ATEN_AVX512_256=TRUE
+#   ATen AVX2 kernels can use 32 ymm registers, instead of the default 16.
+#   This option can be used if AVX512 doesn't perform well on a machine.
+#   The FBGEMM library also uses AVX512_256 kernels on Xeon D processors,
+#   but it also has some (optimized) assembly code.
+#
 # PYTORCH_BUILD_VERSION
 # PYTORCH_BUILD_NUMBER
 #   specify the version of PyTorch, rather than the hard-coded version
@@ -928,6 +934,7 @@ if __name__ == '__main__':
         'include/ATen/*.h',
         'include/ATen/cpu/*.h',
         'include/ATen/cpu/vec/vec256/*.h',
+        'include/ATen/cpu/vec/vec512/*.h',
         'include/ATen/cpu/vec/*.h',
         'include/ATen/core/*.h',
         'include/ATen/cuda/*.cuh',
@@ -29,19 +29,19 @@ TEST_F(DispatchTest, TestAVX2) {
   }
 }

-TEST_F(DispatchTest, TestAVX) {
+TEST_F(DispatchTest, TestAVX512) {
   const std::vector<int> ints {1, 2, 3, 4};
   const std::vector<int> result {1, 4, 27, 256};
   const auto vals_tensor = torch::tensor(ints);
   const auto pows_tensor = torch::tensor(ints);
 #ifdef _WIN32
-  _putenv("ATEN_CPU_CAPABILITY=avx");
+  _putenv("ATEN_CPU_CAPABILITY=avx512");
 #else
-  setenv("ATEN_CPU_CAPABILITY", "avx", 1);
+  setenv("ATEN_CPU_CAPABILITY", "avx512", 1);
 #endif
-  const auto actual_pow_avx = vals_tensor.pow(pows_tensor);
+  const auto actual_pow_avx512 = vals_tensor.pow(pows_tensor);
   for (int i = 0; i < 4; i++) {
-    ASSERT_EQ(result[i], actual_pow_avx[i].item<int>());
+    ASSERT_EQ(result[i], actual_pow_avx512[i].item<int>());
   }
 }
@@ -2,7 +2,7 @@

 import sys
 import os
+import unittest
 # torch
 import torch
 import torch.nn as nn
@@ -11,7 +11,7 @@ import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic.quantized as nniq

 # Testing utils
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED
 from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm

 def remove_prefix(text, prefix):
@@ -238,6 +238,7 @@ class TestSerialization(TestCase):
     # TODO: graph mode quantized conv3d module

     @override_qengines
+    @unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098")
     def test_lstm(self):
         class LSTMModule(torch.nn.Module):
             def __init__(self):
@@ -339,6 +339,15 @@ IS_WINDOWS = sys.platform == "win32"
 IS_MACOS = sys.platform == "darwin"
 IS_PPC = platform.machine() == "ppc64le"

+def is_avx512_vnni_supported():
+    if sys.platform != 'linux':
+        return False
+    with open("/proc/cpuinfo", encoding="ascii") as f:
+        lines = f.read()
+    return "avx512vnni" in lines
+
+IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()
+
 if IS_WINDOWS:
     @contextmanager
     def TemporaryFileName(*args, **kwargs):