Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: This PR allows the compiler to better optimize some bfloat16-based operations when run on NEON. Retrying to land the code after noting that these expressions became available in recent compiler versions. The existing CI benchmark binary_test.py measures the affected code paths.

Benchmarks show measurable improvements with clang-19 when targeting armv9-a+sve2:

Before:
bfloat16 add: 250.503us
bfloat16 sub: 245.674us
bfloat16 neg: 113.945us
bfloat16 abs: 115.953us
bfloat16 reciprocal: 262.602us

After:
bfloat16 add: 203.862us ---> 23% higher throughput
bfloat16 sub: 201.526us ---> 22% higher throughput
bfloat16 neg: 68.416us ---> 67% higher throughput
bfloat16 abs: 71.003us ---> 63% higher throughput
bfloat16 reciprocal: 177.834us ---> 48% higher throughput

Test Plan:
Correctness:
buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch
Performance:
buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Reviewed By: mcfi
Differential Revision: D85809843
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166641
Approved by: https://github.com/Skylion007, https://github.com/malfet
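For context, here is a minimal standalone sketch of the two strategies this change selects between: native bfloat16x8_t arithmetic (which recent clang can lower directly, as the diff below gates with BF16_ARITHMETIC_SUPPORTED()) versus widening each half to float32 and narrowing back. The helper names add_via_float and add_native, the file name, and the compile command are assumptions for illustration; the header itself routes the fallback through its binary_operator_via_float helper over Vectorized<float>.

// example.cpp (assumed name); build assumption:
//   clang-19 -O2 --target=aarch64-linux-gnu -march=armv9-a+sve2 -c example.cpp
#include <arm_neon.h>

#ifdef __ARM_FEATURE_BF16 // bf16 NEON types and intrinsics need the bf16 extension

// Fallback strategy: widen each half to float32, add, narrow back to bf16.
bfloat16x8_t add_via_float(bfloat16x8_t a, bfloat16x8_t b) {
  float32x4_t lo = vaddq_f32(vcvtq_low_f32_bf16(a), vcvtq_low_f32_bf16(b));
  float32x4_t hi = vaddq_f32(vcvtq_high_f32_bf16(a), vcvtq_high_f32_bf16(b));
  return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(lo), hi);
}

#if defined(__clang_major__) && (__clang_major__ >= 19)
// Native strategy: recent clang lowers bfloat16x8_t arithmetic directly,
// which is what the operators changed in this PR rely on.
bfloat16x8_t add_native(bfloat16x8_t a, bfloat16x8_t b) {
  return a + b;
}
#endif

#endif // __ARM_FEATURE_BF16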
This commit is contained in:
parent 0947765eb9
commit b71966f67b
@@ -19,6 +19,13 @@ inline namespace CPU_CAPABILITY {
#error "Big endian is not supported."
#endif

// GCC does not properly optimize bf16 operators
#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
#define BF16_ARITHMETIC_SUPPORTED() 1
#else
#define BF16_ARITHMETIC_SUPPORTED() 0
#endif

// Unlike the float16_t family of types, bfloat16_t is not available
// when we're not targeting bfloat16 hardware support on some
// platforms (but not Mac, so we have to be careful not to shadow the
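The benchmark numbers in the commit message were gathered with clang-19 targeting armv9-a+sve2, a configuration where the native branch gated by the macro above is taken. A small, hedged probe for checking what a given toolchain would do; the file name and invocation are assumptions, adjust to your setup:

// check_bf16.cpp -- e.g.: clang++-19 -c -march=armv9-a+sve2 check_bf16.cpp
#if defined(__ARM_FEATURE_BF16) && defined(__clang_major__) && (__clang_major__ >= 19)
#pragma message("native bf16 path: BF16_ARITHMETIC_SUPPORTED() would be 1")
#else
#pragma message("fallback path: BF16_ARITHMETIC_SUPPORTED() would be 0")
#endif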
@@ -352,18 +359,72 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
        other, &Vectorized<float>::name); \
  }

  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
  Vectorized frac() const;
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)

#ifdef __ARM_FEATURE_BF16
  // Flip sign bit
  Vectorized<c10::BFloat16> neg() const {
    return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
  }
  // Fast reciprocal is fine because we are truncating results
  Vectorized<c10::BFloat16> reciprocal() const {
    auto x = vcvtq_low_f32_bf16(values);
    auto y = vcvtq_high_f32_bf16(values);
    x = vrecpeq_f32(x);
    y = vrecpeq_f32(y);
    return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
  }
  // Clearing the sign bit
  Vectorized<c10::BFloat16> abs() const {
    return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
  }
#else
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
#endif

  // These functions are optimized on clang-21+
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
  Vectorized<c10::BFloat16> operator==(
      const Vectorized<c10::BFloat16>& other) const {
    return values == other.values;
  }

  Vectorized<c10::BFloat16> operator!=(
      const Vectorized<c10::BFloat16>& other) const {
    return values != other.values;
  }

  Vectorized<c10::BFloat16> operator<(
      const Vectorized<c10::BFloat16>& other) const {
    return values < other.values;
  }

  Vectorized<c10::BFloat16> operator<=(
      const Vectorized<c10::BFloat16>& other) const {
    return values <= other.values;
  }

  Vectorized<c10::BFloat16> operator>(
      const Vectorized<c10::BFloat16>& other) const {
    return values > other.values;
  }

  Vectorized<c10::BFloat16> operator>=(
      const Vectorized<c10::BFloat16>& other) const {
    return values >= other.values;
  }
#else
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif

#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
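A side note on the native branch above: neg and abs are pure bit manipulations on the bf16 representation, and reciprocal uses vrecpeq_f32, whose estimate is coarse but roughly on par with the 8-bit mantissa precision that bf16 keeps after truncation, which is why the comment calls the fast reciprocal acceptable. A scalar restatement of the two bit tricks; the helper names are illustrative, not part of the header:

// Operates on the raw 16-bit bf16 pattern (1 sign bit, 8 exponent bits, 7 mantissa bits).
#include <cstdint>

std::uint16_t bf16_neg_bits(std::uint16_t bits) {
  return bits ^ 0x8000u; // flip the sign bit (the vector code XORs with -32768)
}

std::uint16_t bf16_abs_bits(std::uint16_t bits) {
  return bits & 0x7FFFu; // clear the sign bit
}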
@@ -412,28 +473,52 @@ template <>
Vectorized<c10::BFloat16> inline operator+(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  return x + y;
#else
  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
}

template <>
Vectorized<c10::BFloat16> inline operator-(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  return x - y;
#else
  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
}

template <>
Vectorized<c10::BFloat16> inline operator*(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  return x * y;
#else
  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
}

template <>
Vectorized<c10::BFloat16> inline operator/(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  return x / y;
#else
  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
}

// frac. Implement this here so we can use subtraction
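A hedged usage sketch of the overloaded operators above through the at::vec API, as found in the PyTorch tree. The function name and surrounding setup are made up for illustration; only the operator calls correspond to the code paths changed in this PR:

#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>
#include <cstdint>

// out[i] = alpha * x[i] + y[i], vectorized over bf16 lanes with a scalar tail.
void scale_add_bf16(const c10::BFloat16* x, const c10::BFloat16* y,
                    c10::BFloat16* out, float alpha, std::int64_t n) {
  using Vec = at::vec::Vectorized<c10::BFloat16>;
  const Vec valpha(c10::BFloat16{alpha}); // broadcast alpha to all lanes
  std::int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    const Vec vx = Vec::loadu(x + i);
    const Vec vy = Vec::loadu(y + i);
    // operator* and operator+ take the native bf16 branch when
    // BF16_ARITHMETIC_SUPPORTED() is 1, and the float fallback otherwise.
    (valpha * vx + vy).store(out + i);
  }
  for (; i < n; ++i) { // scalar tail
    out[i] = c10::BFloat16(alpha * static_cast<float>(x[i]) + static_cast<float>(y[i]));
  }
}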
@@ -544,12 +629,19 @@ Vectorized<c10::BFloat16> inline fmadd(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b,
    const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  bfloat16x8_t z = c;
  return x * y + z;
#else
  // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
  // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
  // elements, not the bottom and top half, so they don't seem
  // particularly useful here. Ideally we would include dot product in
  // the Vectorized interface...
  return a * b + c;
#endif
}

template <>
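A hedged aside on NOTE [BF16 FMA]: one could still fuse the multiply-add in float32 by widening the halves explicitly and using vfmaq_f32, then truncating back to bf16. This only illustrates the trade-off the note describes; the helper name is an assumption, and the header instead expresses the fallback as a * b + c over Vectorized<float>:

#include <arm_neon.h>

#ifdef __ARM_FEATURE_BF16
// Illustration only: an explicit widen-fma-narrow version of fmadd.
// vfmaq_f32(acc, x, y) computes acc + x * y per lane.
bfloat16x8_t fmadd_via_f32(bfloat16x8_t a, bfloat16x8_t b, bfloat16x8_t c) {
  float32x4_t lo = vfmaq_f32(vcvtq_low_f32_bf16(c),
                             vcvtq_low_f32_bf16(a),
                             vcvtq_low_f32_bf16(b));
  float32x4_t hi = vfmaq_f32(vcvtq_high_f32_bf16(c),
                             vcvtq_high_f32_bf16(a),
                             vcvtq_high_f32_bf16(b));
  return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(lo), hi);
}
#endif // __ARM_FEATURE_BF16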
@@ -557,8 +649,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b,
    const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  bfloat16x8_t z = c;
  return (-x) * y + z;
#else
  // See NOTE [BF16 FMA] above.
  return -a * b + c;
#endif
}

template <>
@@ -566,8 +665,15 @@ Vectorized<c10::BFloat16> inline fmsub(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b,
    const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  bfloat16x8_t z = c;
  return x * y - z;
#else
  // See NOTE [BF16 FMA] above.
  return a * b - c;
#endif
}

template <>
@@ -575,8 +681,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b,
    const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
  bfloat16x8_t x = a;
  bfloat16x8_t y = b;
  bfloat16x8_t z = c;
  return (-x) * y - z;
#else
  // See NOTE [BF16 FMA] above.
  return -a * b - c;
#endif
}

#endif // !defined(C10_MOBILE) && defined(__aarch64__)
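For quick reference, the sign conventions of the four fused helpers touched above, restated as scalars directly from their native branches (nothing new, just a compact summary; the *_ref names are illustrative):

// fmadd(a, b, c)  =  a * b + c
// fnmadd(a, b, c) = -a * b + c
// fmsub(a, b, c)  =  a * b - c
// fnmsub(a, b, c) = -a * b - c
float fmadd_ref (float a, float b, float c) { return  a * b + c; }
float fnmadd_ref(float a, float b, float c) { return -a * b + c; }
float fmsub_ref (float a, float b, float c) { return  a * b - c; }
float fnmsub_ref(float a, float b, float c) { return -a * b - c; }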