diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp
index 5d15b80fcb8..51d5f2d6412 100644
--- a/aten/src/ATen/Version.cpp
+++ b/aten/src/ATen/Version.cpp
@@ -105,11 +105,6 @@ std::string get_cpu_capability() {
       return "DEFAULT";
     case native::CPUCapability::ZVECTOR:
       return "Z VECTOR";
-#elif defined(HAVE_SVE256_BF16_CPU_DEFINITION)
-    case native::CPUCapability::DEFAULT:
-      return "DEFAULT";
-    case native::CPUCapability::SVE256_BF16:
-      return "SVE256_BF16";
 #elif defined(HAVE_SVE_CPU_DEFINITION)
     case native::CPUCapability::DEFAULT:
       return "DEFAULT";
diff --git a/aten/src/ATen/cpu/vec/sve/sve_helper.h b/aten/src/ATen/cpu/vec/sve/sve_helper.h
index 1aa37887bd5..e511ebb52b2 100644
--- a/aten/src/ATen/cpu/vec/sve/sve_helper.h
+++ b/aten/src/ATen/cpu/vec/sve/sve_helper.h
@@ -17,7 +17,6 @@ typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH
 typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
 typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
 typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
 typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
-typedef svbfloat16_t vls_bfloat16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
 typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
 typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
@@ -42,7 +41,6 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
 #define ONE_U32 svdup_n_u32(1)
 #define ONE_U64 svdup_n_u64(1)
 #define ONE_F16 svdup_n_f16(1.f)
-#define ONE_BF16 svdup_n_bf16(1.f)
 #define ONE_F32 svdup_n_f32(1.f)
 #define ONE_F64 svdup_n_f64(1.0)
 #define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
@@ -57,8 +55,6 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
 #define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
 #define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
 #define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
-#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK)
-#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK)
 #define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
 #define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
 #define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
diff --git a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h
deleted file mode 100644
index 8f06fe12636..00000000000
--- a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h
+++ /dev/null
@@ -1,524 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-namespace at {
-namespace vec {
-// Note [CPU_CAPABILITY namespace]
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// This header, and all of its subheaders, will be compiled with
-// different architecture flags for each supported set of vector
-// intrinsics. So we need to make sure they aren't inadvertently
-// linked together. We do this by declaring objects in an `inline
-// namespace` which changes the name mangling, but can still be
-// accessed as `at::vec`.
-inline namespace CPU_CAPABILITY {
-
-#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
-
-template <>
-class Vectorized<BFloat16> {
- private:
-  vls_bfloat16_t values;
-
- public:
-  using value_type = BFloat16;
-  using size_type = int;
-
-  static constexpr size_type size() {
-    return VECTOR_WIDTH / sizeof(BFloat16);
-  }
-
-  Vectorized() {}
-  Vectorized(svbfloat16_t v) : values(v) {}
-  Vectorized(int val);
-  Vectorized(BFloat16 val);
-
-  template <
-      typename... Args,
-      typename = std::enable_if_t<(sizeof...(Args) == size())>>
-  Vectorized(Args... vals) {
-    __at_align__ BFloat16 buffer[size()] = {vals...};
-    values = svld1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(buffer));
-  }
-
-  operator svbfloat16_t() const {
-    return values;
-  }
-  static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a,
-      const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask_) {
-    svbool_t mask = svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_),
-        ALL_S16_TRUE_MASK);
-    return svsel_bf16(mask, b, a);
-  }
-  template <typename step_t>
-  static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step = static_cast<step_t>(1)) {
-    __at_align__ BFloat16 buffer[size()];
-    for (int64_t i = 0; i < size(); i++) {
-      buffer[i] = base + i * step;
-    }
-    return svld1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(buffer));
-  }
-  static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a,
-      const Vectorized<BFloat16>& b, int64_t count = size()) {
-    if (count == 0) {
-      return a;
-    } else if (count < size()) {
-      return svsel_bf16(svwhilelt_b16(0ull, count), b, a);
-    }
-    return b;
-  }
-  static Vectorized<BFloat16> loadu(const void* ptr, int64_t count = size()) {
-    if (count == size())
-      return svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(ptr));
-    svbool_t pg = svwhilelt_b16(0ull, count);
-    return svld1_bf16(pg, reinterpret_cast<const bfloat16_t*>(ptr));
-  }
-  void store(void* ptr, int64_t count = size()) const {
-    __at_align__ bfloat16_t tmp[size()];
-    std::memset(tmp, 0, sizeof(tmp));
-    if (count == size()) {
-      svst1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(tmp), values);
-    } else {
-      svbool_t pg = svwhilelt_b16(0ull, count);
-      svst1_bf16(pg, reinterpret_cast<bfloat16_t*>(tmp), values);
-    }
-    std::memcpy(
-        reinterpret_cast<bfloat16_t*>(ptr),
-        reinterpret_cast<const bfloat16_t*>(tmp),
-        count * sizeof(bfloat16_t));
-  }
-  const BFloat16& operator[](int idx) const = delete;
-  BFloat16& operator[](int idx) = delete;
-  int64_t zero_mask() const {
-    // returns an integer mask where all zero elements are translated to
-    // 1-bit and others are translated to 0-bit
-    int64_t mask = 0;
-    __at_align__ int16_t mask_array[size()];
-
-    svbool_t svbool_mask = svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16);
-    svst1_s16(ptrue, mask_array, svsel_s16(svbool_mask,
-        ALL_S16_TRUE_MASK,
-        ALL_S16_FALSE_MASK));
-    for (int64_t i = 0; i < size(); ++i) {
-      if (mask_array[i]) mask |= (1ull << i);
-    }
-    return mask;
-  }
-  Vectorized<BFloat16> isnan() const;
-  bool has_inf_nan() const;
-  Vectorized<BFloat16> map(BFloat16 (*f)(BFloat16)) const {
-    __at_align__ BFloat16 tmp[size()];
-    store(tmp);
-    for (int64_t i = 0; i < size(); ++i) {
-      tmp[i] = f(tmp[i]);
-    }
-    return loadu(tmp);
-  }
-  Vectorized<BFloat16> abs() const {
-    auto mask = svdup_n_u16(0x7FFF);
-    auto vals = svreinterpret_u16_bf16(values);
-    vals = svand_u16_x(ptrue, vals, mask);
-    return svreinterpret_bf16_u16(vals);
-  }
-  Vectorized<BFloat16> angle() const;
-  Vectorized<BFloat16> real() const {
-    return values;
-  }
-  Vectorized<BFloat16> imag() const {
-    return Vectorized<BFloat16>(0.f);
-  }
-  Vectorized<BFloat16> conj() const {
-    return values;
-  }
-  Vectorized<BFloat16> acos() const;
-  Vectorized<BFloat16> acosh() const;
-  Vectorized<BFloat16> asin() const;
-  Vectorized<BFloat16> atan() const;
-  Vectorized<BFloat16> atanh() const;
-  Vectorized<BFloat16> atan2(const Vectorized<BFloat16>& b) const;
-  Vectorized<BFloat16> copysign(const Vectorized<BFloat16>& sign) const;
-  Vectorized<BFloat16> erf() const;
-  Vectorized<BFloat16> erfc() const;
-  Vectorized<BFloat16> erfinv() const;
-  Vectorized<BFloat16> exp() const;
-  Vectorized<BFloat16> exp2() const;
-  Vectorized<BFloat16> expm1() const;
-  Vectorized<BFloat16> exp_u20() const {
-    return exp();
-  }
-  Vectorized<BFloat16> fmod(const Vectorized<BFloat16>& q) const;
-  Vectorized<BFloat16> hypot(const Vectorized<BFloat16>& b) const;
-  Vectorized<BFloat16> i0() const;
-  Vectorized<BFloat16> i0e() const;
-  Vectorized<BFloat16> digamma() const;
-  Vectorized<BFloat16> igamma(const Vectorized<BFloat16>& x) const;
-  Vectorized<BFloat16> igammac(const Vectorized<BFloat16>& x) const;
-  Vectorized<BFloat16> nextafter(const Vectorized<BFloat16>& b) const;
-  Vectorized<BFloat16> log() const;
-  Vectorized<BFloat16> log2() const;
-  Vectorized<BFloat16> log10() const;
-  Vectorized<BFloat16> log1p() const;
-  Vectorized<BFloat16> frac() const;
-  Vectorized<BFloat16> sin() const;
-  Vectorized<BFloat16> sinh() const;
-  Vectorized<BFloat16> cos() const;
-  Vectorized<BFloat16> cosh() const;
-  Vectorized<BFloat16> ceil() const;
-  Vectorized<BFloat16> floor() const;
-  Vectorized<BFloat16> neg() const {
-    auto mask = svdup_n_u16(0x8000);
-    auto vals = svreinterpret_u16_bf16(values);
-    vals = sveor_u16_x(ptrue, vals, mask);
-    return svreinterpret_bf16_u16(vals);
-  };
-  Vectorized<BFloat16> round() const;
-  Vectorized<BFloat16> tan() const;
-  Vectorized<BFloat16> tanh() const;
-  Vectorized<BFloat16> trunc() const;
-  Vectorized<BFloat16> lgamma() const;
-  Vectorized<BFloat16> sqrt() const;
-  Vectorized<BFloat16> reciprocal() const;
-  Vectorized<BFloat16> rsqrt() const;
-  Vectorized<BFloat16> pow(const Vectorized<BFloat16>& b) const;
-  // Comparison using the _CMP_**_OQ predicate.
-  //   `O`: get false if an operand is NaN
-  //   `Q`: do not raise if an operand is NaN
-  Vectorized<BFloat16> operator==(const Vectorized<BFloat16>& other) const;
-
-  Vectorized<BFloat16> operator!=(const Vectorized<BFloat16>& other) const;
-
-  Vectorized<BFloat16> operator<(const Vectorized<BFloat16>& other) const;
-
-  Vectorized<BFloat16> operator<=(const Vectorized<BFloat16>& other) const;
-
-  Vectorized<BFloat16> operator>(const Vectorized<BFloat16>& other) const;
-
-  Vectorized<BFloat16> operator>=(const Vectorized<BFloat16>& other) const;
-
-  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
-  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
-  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
-  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
-  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
-  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
-};
-
-inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
-    const Vectorized<BFloat16>& a) {
-  static_assert(
-      Vectorized<BFloat16>::size() == 2 * Vectorized<float>::size());
-  auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f));
-  auto bf16_vec1 = svzip1_bf16(zero, a);
-  auto bf16_vec2 = svzip2_bf16(zero, a);
-  auto x1 = svreinterpret_f32_bf16(bf16_vec1);
-  auto x2 = svreinterpret_f32_bf16(bf16_vec2);
-  return {Vectorized<float>(x1), Vectorized<float>(x2)};
-}
-
-inline Vectorized<BFloat16> convert_float_bfloat16(
-    const Vectorized<float>& a,
-    const Vectorized<float>& b) {
-  static_assert(
-      Vectorized<BFloat16>::size() == 2 * Vectorized<float>::size());
-  svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a);
-  svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b);
-  return Vectorized<BFloat16>(svuzp1_bf16(x1, x2));
-}
-
-inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
-  __at_align__ float values[Vectorized<float>::size()];
-  for (const auto k : c10::irange(Vectorized<float>::size())) {
-    values[k] = data[k];
-  }
-  out = Vectorized<float>::loadu(values);
-}
-
-inline void load_fp32_from_bf16(
-    const BFloat16* data,
-    Vectorized<float>& out1,
-    Vectorized<float>& out2) {
-  Vectorized<BFloat16> bf16_vec = Vectorized<BFloat16>::loadu(data);
-  auto floats = convert_bfloat16_float(bf16_vec);
-  out1 = std::get<0>(floats);
-  out2 = std::get<1>(floats);
-}
-
-template <typename Op>
-Vectorized<BFloat16> binary_operator_via_float(
-    Op op,
-    const Vectorized<BFloat16>& a,
-    const Vectorized<BFloat16>& b) {
-  const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
-  const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return convert_float_bfloat16(
-      op(a_float_low, b_float_low), op(a_float_high, b_float_high));
-}
-
-template <>
-Vectorized<BFloat16> inline operator+(
-    const Vectorized<BFloat16>& a,
-    const Vectorized<BFloat16>& b) {
-  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
-}
-
-template <>
-Vectorized<BFloat16> inline operator-(
-    const Vectorized<BFloat16>& a,
-    const Vectorized<BFloat16>& b) {
-  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
-}
-
-template <>
-Vectorized<BFloat16> inline operator*(
-    const Vectorized<BFloat16>& a,
-    const Vectorized<BFloat16>& b) {
-  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
-}
-
-template <>
-Vectorized<BFloat16> inline operator/(
-    const Vectorized<BFloat16>& a,
-    const Vectorized<BFloat16>& b) {
-  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
-}
-
-inline Vectorized<BFloat16>::Vectorized(int val) {
-  auto vals_f = svdup_n_f32(val);
-  values = convert_float_bfloat16(vals_f, vals_f);
-}
-
-
-inline Vectorized<BFloat16>::Vectorized(BFloat16 val) {
-  auto vals_f = svdup_n_f32((float) val);
-  values = convert_float_bfloat16(vals_f, vals_f);
-}
-
-bool inline Vectorized<BFloat16>::has_inf_nan() const {
-  auto [v1, v2] = convert_bfloat16_float(values);
-  return v1.has_inf_nan() || v2.has_inf_nan();
-}
-// frac. Implement this here so we can use subtraction
-Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
-  return *this - this->trunc();
-}
-
-#define DEFINE_BF16_FUNC_VIA_FLOAT(func_name)                            \
-  Vectorized<BFloat16> inline Vectorized<BFloat16>::func_name() const {  \
-    auto [v1, v2] = convert_bfloat16_float(*this);                       \
-    v1 = v1.func_name();                                                 \
-    v2 = v2.func_name();                                                 \
-    return convert_float_bfloat16(v1, v2);                               \
-  }
-
-#define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name)                      \
-  Vectorized<BFloat16> inline Vectorized<BFloat16>::func_name(const Vectorized<BFloat16>& a) const { \
-    auto [v1, v2] = convert_bfloat16_float(*this);                       \
-    auto [v3, v4] = convert_bfloat16_float(a);                           \
-    v1 = v1.func_name(v3);                                               \
-    v2 = v2.func_name(v4);                                               \
-    return convert_float_bfloat16(v1, v2);                               \
-  }
-
-DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
-DEFINE_BF16_FUNC_VIA_FLOAT(angle);
-DEFINE_BF16_FUNC_VIA_FLOAT(acos);
-DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
-DEFINE_BF16_FUNC_VIA_FLOAT(asin);
-DEFINE_BF16_FUNC_VIA_FLOAT(atan);
-DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
-DEFINE_BF16_FUNC_VIA_FLOAT(erf);
-DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
-DEFINE_BF16_FUNC_VIA_FLOAT(exp);
-DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
-DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
-DEFINE_BF16_FUNC_VIA_FLOAT(i0);
-DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
-DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
-DEFINE_BF16_FUNC_VIA_FLOAT(log);
-DEFINE_BF16_FUNC_VIA_FLOAT(log2);
-DEFINE_BF16_FUNC_VIA_FLOAT(log10);
-DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
-DEFINE_BF16_FUNC_VIA_FLOAT(sin);
-DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
-DEFINE_BF16_FUNC_VIA_FLOAT(cos);
-DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
-DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
-DEFINE_BF16_FUNC_VIA_FLOAT(floor);
-DEFINE_BF16_FUNC_VIA_FLOAT(round);
-DEFINE_BF16_FUNC_VIA_FLOAT(tan);
-DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
-DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
-DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
-DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
-DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
-DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const {
-  auto [f1, f2] = convert_bfloat16_float(values);
-  auto [f3, f4] = convert_bfloat16_float(other);
-  svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3);
-  svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4);
-  auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
-  auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
-
-  auto bf16_1 = svreinterpret_bf16_f32(res1);
-  auto bf16_2 = svreinterpret_bf16_f32(res2);
-  return svuzp1_bf16(bf16_1, bf16_2);
-}
-Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
-  auto [f1, f2] = convert_bfloat16_float(values);
-  auto [f3, f4] = convert_bfloat16_float(other);
-  svbool_t mask1 = svcmpne_f32(ptrue, f1, f3);
-  svbool_t mask2 = svcmpne_f32(ptrue, f2, f4);
-  auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
-  auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
-
-  auto bf16_1 = svreinterpret_bf16_f32(res1);
-  auto bf16_2 = svreinterpret_bf16_f32(res2);
-  return svuzp1_bf16(bf16_1, bf16_2);
-}
-Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const {
-  auto [v1, v2] = convert_bfloat16_float(*this);
-  auto [v3, v4] = convert_bfloat16_float(other);
-  return convert_float_bfloat16(v1 > v3, v2 > v4);
-}
-Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const {
-  auto [v1, v2] = convert_bfloat16_float(*this);
-  auto [v3, v4] = convert_bfloat16_float(other);
-  return convert_float_bfloat16(v1 >= v3, v2 >= v4);
-}
-Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const {
-  auto [v1, v2] = convert_bfloat16_float(*this);
-  auto [v3, v4] = convert_bfloat16_float(other);
-  return convert_float_bfloat16(v1 < v3, v2 < v4);
-}
-Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const {
-  auto [v1, v2] = convert_bfloat16_float(*this);
-  auto [v3, v4] = convert_bfloat16_float(other);
-  return convert_float_bfloat16(v1 <= v3, v2 <= v4);
-}
-
-// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
-// either input is a NaN.
-template <>
-Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
-  return binary_operator_via_float(
-      static_cast<Vectorized<float> (*)(const Vectorized<float>&, const Vectorized<float>&)>(&maximum),
-      a, b);
-}
-
-// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
-// either input is a NaN.
-template <>
-Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
-  return binary_operator_via_float(
-      static_cast<Vectorized<float> (*)(const Vectorized<float>&, const Vectorized<float>&)>(&minimum),
-      a, b);
-}
-
-template <>
-Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
-  return binary_operator_via_float(
-      static_cast<Vectorized<float> (*)(const Vectorized<float>&, const Vectorized<float>&)>(&clamp_max),
-      a, max);
-}
-
-template <>
-Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
-  return binary_operator_via_float(
-      static_cast<Vectorized<float> (*)(const Vectorized<float>&, const Vectorized<float>&)>(&clamp_min),
-      a, min);
-}
-
-template <>
-Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
-  return clamp_min(clamp_max(a, max), min);
-}
-
-template <>
-Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
-  return svreinterpret_bf16_u16(svand_u16_x(ptrue, svreinterpret_u16_bf16(a),
-      svreinterpret_u16_bf16(b)));
-}
-
-template <>
-Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
-  return svreinterpret_bf16_u16(svorr_u16_x(ptrue, svreinterpret_u16_bf16(a),
-      svreinterpret_u16_bf16(b)));
-}
-
-template <>
-Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
-  return svreinterpret_bf16_u16(sveor_u16_x(ptrue, svreinterpret_u16_bf16(a),
-      svreinterpret_u16_bf16(b)));
-}
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
-  return (*this == other) & Vectorized<BFloat16>(1.0f);
-}
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
-  return (*this != other) & Vectorized<BFloat16>(1.0f);
-}
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
-  return (*this > other) & Vectorized<BFloat16>(1.0f);
-}
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
-  return (*this >= other) & Vectorized<BFloat16>(1.0f);
-}
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
-  return (*this < other) & Vectorized<BFloat16>(1.0f);
-}
-
-Vectorized<BFloat16> inline Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
-  return (*this <= other) & Vectorized<BFloat16>(1.0f);
-}
-
-template <>
-inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
-  const int64_t fraction = n % Vectorized<BFloat16>::size();
-#pragma unroll
-  for (int64_t i = 0; i < n - fraction; i += Vectorized<BFloat16>::size()) {
-    svst1_bf16(ptrue, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(dst)) + i,
-        svldnt1_bf16(ptrue, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(src)) + i));
-  }
-#pragma unroll
-  for (int64_t i = n - fraction; i < n; i += Vectorized<BFloat16>::size()) {
-    svbool_t pg = svwhilelt_b16(i, n);
-    svst1_bf16(pg, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(dst)) + i,
-        svldnt1_bf16(pg, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(src)) + i));
-  }
-}
-
-template <>
-Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
-  return a * b + c;
-}
-
-#endif // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16)
-
-} // namespace CPU_CAPABILITY
-} // namespace vec
-} // namespace at
diff --git a/aten/src/ATen/cpu/vec/sve/vec_common_sve.h b/aten/src/ATen/cpu/vec/sve/vec_common_sve.h
index 2f65c3ec4ad..c7968e271f9 100644
--- a/aten/src/ATen/cpu/vec/sve/vec_common_sve.h
+++ b/aten/src/ATen/cpu/vec/sve/vec_common_sve.h
@@ -13,7 +13,6 @@
 #include <ATen/cpu/vec/sve/vec_double.h>
 #include <ATen/cpu/vec/sve/vec_int.h>
 #include <ATen/cpu/vec/sve/vec_qint.h>
-#include <ATen/cpu/vec/sve/vec_bfloat16.h>
 #endif
 
@@ -31,29 +30,33 @@ inline namespace CPU_CAPABILITY {
 #if defined(CPU_CAPABILITY_SVE)
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix)                \
-template<>                                                               \
-inline Vectorized<t1_t> cast<t1_t, t2_t>(const Vectorized<t2_t>& src) {  \
-  return svreinterpret_##t1_prefix##_##t2_prefix(src);                   \
-}                                                                        \
-template<>                                                               \
-inline Vectorized<t2_t> cast<t2_t, t1_t>(const Vectorized<t1_t>& src) {  \
-  return svreinterpret_##t2_prefix##_##t1_prefix(src);                   \
+
+template<>
+inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
+  return svreinterpret_f32_f64(src);
 }
 
-DEFINE_SVE_CAST(int64_t, s64, double, f64)
-DEFINE_SVE_CAST(int32_t, s32, double, f64)
-DEFINE_SVE_CAST(int16_t, s16, double, f64)
-DEFINE_SVE_CAST(int64_t, s64, float, f32)
-DEFINE_SVE_CAST(int32_t, s32, float, f32)
-DEFINE_SVE_CAST(int16_t, s16, float, f32)
-DEFINE_SVE_CAST(float, f32, double, f64)
+template<>
+inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
+  return svreinterpret_f64_f32(src);
+}
 
-#ifdef __ARM_FEATURE_BF16
-DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16)
-DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16)
-DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16)
-#endif // __ARM_FEATURE_BF16
+#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit)                  \
+template<>                                                                         \
+inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) {    \
+  return svreinterpret_s##int_bit##_f##float_bit(src);                             \
+}                                                                                  \
+template<>                                                                         \
+inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) {    \
+  return svreinterpret_f##float_bit##_s##int_bit(src);                             \
+}
+
+DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
+DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
+DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
+DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
+DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
+DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -140,21 +143,6 @@ inline interleave2(const Vectorized<float>& a, const Vectorized<float>& b
       Vectorized<float>(svzip2_f32(a, b)));
 }
 
-#ifdef __ARM_FEATURE_BF16
-template <>
-std::pair<Vectorized<c10::BFloat16>, Vectorized<c10::BFloat16>>
-inline interleave2(const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& b) {
-  // inputs:
-  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
-  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
-  // group cols crossing lanes:
-  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
-  //          {a4, b4, a5, b5, a6, b6, a7, b7}
-  return std::make_pair(Vectorized<c10::BFloat16>(svzip1_bf16(a, b)),
-                        Vectorized<c10::BFloat16>(svzip2_bf16(a, b)));
-}
-#endif // __ARM_FEATURE_BF16
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <>
@@ -183,21 +171,6 @@ inline deinterleave2(const Vectorized<float>& a, const Vectorized<float>&
       Vectorized<float>(svuzp2_f32(a, b)));
 }
 
-#ifdef __ARM_FEATURE_BF16
-template <>
-std::pair<Vectorized<c10::BFloat16>, Vectorized<c10::BFloat16>>
-inline deinterleave2(const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& b) {
-  // inputs:
-  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
-  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
-  // swap lanes:
-  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
-  //          {b0, b1, b2, b3, b4, b5, b6, b7}
-  return std::make_pair(Vectorized<c10::BFloat16>(svuzp1_bf16((svbfloat16_t) a, (svbfloat16_t) b)),
-                        Vectorized<c10::BFloat16>(svuzp2_bf16((svbfloat16_t) a, (svbfloat16_t) b)));
-}
-#endif // __ARM_FEATURE_BF16
-
 #endif // defined(CPU_CAPABILITY_SVE)
 
 }}
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h
index df066c417d6..83bb70bdbcb 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256.h
@@ -9,16 +9,13 @@
 #if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
 #if defined(CPU_CAPABILITY_SVE256)
 #include <ATen/cpu/vec/sve/vec_common_sve.h>
-#else
+#endif
 #include
+#include
+#include
 #include
 #include
 #include
-#endif
-#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16)
-#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
-#endif
-#include
 #include
 #include
 #elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h
index c1d76377cbc..9dbdb4f3dfb 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h
@@ -299,46 +299,6 @@
 };
 #endif
 
-#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
-
-template <>
-struct VecConvert<float, 1, BFloat16, 1> {
-  static inline VectorizedN<float, 1> apply(
-      const VectorizedN<BFloat16, 1>& src) {
-    VectorizedN<float, 1> res;
-    // Load 16-bit unsigned integers from src into an SVE vector
-    svuint16_t u16x4 = svld1_u16(svptrue_b16(), reinterpret_cast<const uint16_t*>(&src[0]));
-    // Zero-extend to 32-bit. SVE does not have a direct vmovl_u16 equivalent.
-    vls_uint32_t u32x4 = svreinterpret_u32_u16(svzip1_u16(svdup_n_u16(0), u16x4));
-    // Reinterpret as float32
-    vls_float32_t f32x4 = svreinterpret_f32_u32(u32x4);
-    res[0] = Vectorized<float>(f32x4);
-    return res;
-  }
-};
-
-template <>
-struct VecConvert<float, 2, BFloat16, 1> {
-  static inline VectorizedN<float, 2> apply(
-      const VectorizedN<BFloat16, 1>& src) {
-    VectorizedN<float, 2> res;
-    std::tie(res[0], res[1]) = convert_bfloat16_float(src[0]);
-    return res;
-  }
-};
-
-template <>
-struct VecConvert<BFloat16, 1, float, 2> {
-  static inline VectorizedN<BFloat16, 1> apply(
-      const VectorizedN<float, 2>& src) {
-    VectorizedN<BFloat16, 1> res;
-    res[0] = convert_float_bfloat16(src[0], src[1]);
-    return res;
-  }
-};
-
-#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
-
 template <typename src_t>
 struct VecConvert<
     float,
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index 85dce1cc263..2591338881a 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -1068,7 +1068,7 @@ inline Vectorized<IntType> convert_to_int_of_same_size(const Vectorized<T>& src)
   static_assert(sizeof(T) == sizeof(IntType));
   static constexpr int size = Vectorized<T>::size();
 
-  std::array<T, size> src_arr = {};
+  std::array<T, size> src_arr;
   src.store(static_cast<void*>(src_arr.data()));
   std::array<IntType, size> buffer;
   std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp
index aaf2c581efe..1be4ec37dfe 100644
--- a/aten/src/ATen/native/DispatchStub.cpp
+++ b/aten/src/ATen/native/DispatchStub.cpp
@@ -41,11 +41,7 @@ static CPUCapability compute_cpu_capability() {
 #ifdef HAVE_SVE256_CPU_DEFINITION
     if (strcmp(envar, "sve256") == 0) {
       if (sve_vl == 256) {
-#ifdef HAVE_SVE256_BF16_CPU_DEFINITION
-        return CPUCapability::SVE256_BF16;
-#else
         return CPUCapability::SVE256;
-#endif
       }
       TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
       return CPUCapability::DEFAULT;
@@ -106,11 +102,7 @@ static CPUCapability compute_cpu_capability() {
   }
 #ifdef HAVE_SVE256_CPU_DEFINITION
   if (sve_vl == 256) { // Check for SVE256
-  #ifdef HAVE_SVE256_BF16_CPU_DEFINITION
-    return CPUCapability::SVE256_BF16;
-  #else
     return CPUCapability::SVE256;
-  #endif
   }
 #endif
   // Return the default CPU capability.
diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h
index 15ac7f20a00..725d0d08bae 100644
--- a/aten/src/ATen/native/DispatchStub.h
+++ b/aten/src/ATen/native/DispatchStub.h
@@ -63,10 +63,7 @@ enum class CPUCapability {
   VSX = 1,
 #elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
   ZVECTOR = 1,
-#elif defined(HAVE_SVE256_BF16_CPU_DEFINITION)
-  SVE256 = 1,
-  SVE256_BF16 = 1,
-#elif defined(HAVE_SVE256_CPU_DEFINITION)
+#elif defined(HAVE_SVE_CPU_DEFINITION)
   SVE256 = 1,
 #else
   AVX2 = 1,
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 3db9646b31c..42a4d0b564b 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -274,26 +274,6 @@ inline Vectorized<scalar_t> div_floor_floating_vec(
   return floordiv;
 }
 
-#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
-
-// Since sve lacks sufficient bf16 intrinsics, do the calculations in f32 to
-// avoid rounding errors. This should not cause performance issues as
-// most of the used instructions would be cast to f32 vectors anyway.
-template<>
-inline Vectorized<c10::BFloat16> div_floor_floating_vec(
-    const Vectorized<c10::BFloat16>& a,
-    const Vectorized<c10::BFloat16>& b) {
-  auto [a1, a2] = convert_bfloat16_float(a);
-  auto [b1, b2] = convert_bfloat16_float(b);
-
-  auto res1 = div_floor_floating_vec(a1, b1);
-  auto res2 = div_floor_floating_vec(a2, b2);
-
-  return convert_float_bfloat16(res1, res2);
-}
-
-#endif
-
 void div_floor_kernel(TensorIteratorBase& iter) {
   const auto dtype = iter.common_dtype();
   if (dtype == kByte) {
diff --git a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp
index 8a9814ac6cc..9f6f17de9a8 100644
--- a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp
@@ -236,7 +236,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
 // Return a + b_low * c_low + b_high * c_high
 vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
-#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)
+#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML)
 // NOTE: this instruction is an optional instruction in ARM v8.2 and
 // v8.3, but mandatory in v8.4 per
 // https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp
index d2cc963b093..4e078080090 100644
--- a/aten/src/ATen/test/vec_test_all_types.cpp
+++ b/aten/src/ATen/test/vec_test_all_types.cpp
@@ -566,19 +566,6 @@ namespace {
             }
         }
     }
-#if defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16)
-    TEST(NanBfloat16, IsNan) {
-      for (unsigned int ii = 0; ii < 0xFFFF; ++ii) {
-        c10::BFloat16 val(ii, c10::BFloat16::from_bits());
-        bool expected = std::isnan(val);
-        CACHE_ALIGN c10::BFloat16 actual_vals[at::vec::SVE256::Vectorized<c10::BFloat16>::size()];
-        at::vec::SVE256::Vectorized<c10::BFloat16>(val).isnan().store(actual_vals);
-        for (int jj = 0; jj < at::vec::SVE256::Vectorized<c10::BFloat16>::size(); ++jj) {
-          EXPECT_EQ(expected, c10::bit_cast<uint16_t>(actual_vals[jj]) != 0) << "bf16 isnan failure for bit pattern " << std::hex << ii << std::dec;
-        }
-      }
-    }
-#endif
     TYPED_TEST(LGamma, LGamma) {
         using vec = TypeParam;
         using UVT = UvalueType<vec>;
diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake
index 75be5d043e4..1738a20d835 100644
--- a/cmake/Codegen.cmake
+++ b/cmake/Codegen.cmake
@@ -383,24 +383,17 @@ if(INTERN_BUILD_ATEN_OPS)
     LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
   endif(CXX_ZVECTOR_FOUND)
 
-  if(CXX_SVE_FOUND AND CXX_SVE256_FOUND)
-    list(APPEND CPU_CAPABILITY_NAMES "SVE256")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION")
-    if(CXX_SVE256_BF16_FOUND)
-      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE256_BF16_CPU_DEFINITION")
-      if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
-        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
-      else()
-        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
-      endif()
-    else()
-      if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
-        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
-      else()
-        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
-      endif()
-    endif()
-  endif(CXX_SVE_FOUND AND CXX_SVE256_FOUND)
+  if(CXX_SVE_FOUND)
+    if(CXX_SVE256_FOUND)
+      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION")
+      list(APPEND CPU_CAPABILITY_NAMES "SVE256")
+      if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
+        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
+      else()
+        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
+      endif()
+    endif(CXX_SVE256_FOUND)
+  endif(CXX_SVE_FOUND)
 
   list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
   math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")
diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake
index e628cf7db34..340c2a64e03 100644
--- a/cmake/Modules/FindARM.cmake
+++ b/cmake/Modules/FindARM.cmake
@@ -106,18 +106,8 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
    }
   ")
 
-  SET(SVE_BF16_CODE "
-  #include <arm_sve.h>
-  int main()
-  {
-    svfloat32_t a = svdup_n_f32(0);
-    svbfloat16_t b = svreinterpret_bf16_f32(a);
-    return 0;
-  }
-  ")
-
   # Macro to check for SVE instruction support
-  MACRO(CHECK_COMPILES lang type flags code)
+  MACRO(CHECK_SVE lang type flags)
     # Save the current state of required flags
     SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
@@ -152,8 +142,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
   ENDMACRO()
 
   # Check for SVE256 vector length
-  CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" SVE_CODE)
-  CHECK_COMPILES(CXX "SVE256_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" SVE_BF16_CODE)
+  CHECK_SVE(CXX "SVE256" "-march=armv8-a+sve -msve-vector-bits=256")
 
   # If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user
   if(NOT CXX_SVE256_FOUND)
diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py
index 5fdf6ab5212..0c8c315bbc1 100644
--- a/torch/_inductor/cpu_vec_isa.py
+++ b/torch/_inductor/cpu_vec_isa.py
@@ -187,19 +187,6 @@ class VecSVE256(VecISA):
     __hash__: Callable[[VecISA], Any] = VecISA.__hash__
 
 
-@dataclasses.dataclass
-class VecSVE256_BF16(VecSVE256):
-    _macro = [
-        "CPU_CAPABILITY_SVE",
-        "CPU_CAPABILITY_SVE256",
-        "AT_BUILD_ARM_VEC256_WITH_SLEEF",
-        "__ARM_FEATURE_BF16",
-    ]
-    _arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=256"
-
-    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
-
-
 @dataclasses.dataclass
 class VecAVX512(VecISA):
     _bit_width = 512
@@ -345,14 +332,7 @@ def x86_isa_checker() -> list[str]:
 
 invalid_vec_isa = InvalidVecISA()
 
-supported_vec_isa_list = [
-    VecAMX(),
-    VecAVX512(),
-    VecAVX2(),
-    VecNEON(),
-    VecSVE256(),
-    VecSVE256_BF16(),
-]
+supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()]
 
 
 def get_isa_from_cpu_capability(
@@ -413,13 +393,10 @@ def valid_vec_isa_list() -> list[VecISA]:
     elif arch == "ppc64le":
         isa_list.append(VecVSX())
     elif arch == "aarch64":
-        if torch.backends.cpu.get_cpu_capability() == "SVE256_BF16":
-            isa_list.append(VecSVE256_BF16())
-        elif torch.backends.cpu.get_cpu_capability() == "SVE256":
+        if torch.backends.cpu.get_cpu_capability() == "SVE256":
             isa_list.append(VecSVE256())
         else:
             isa_list.append(VecNEON())
-
     elif arch in ["x86_64", "AMD64"]:
         """
         arch value is x86_64 on Linux, and the value is AMD64 on Windows.