mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove is_reduced_floating_point from namespace std (#144502)
Partial fix for #144495. Avoids a BC break by following existing practice: the `std::is_reduced_floating_point` and `std::is_reduced_floating_point_v` aliases are kept by default and are removed only when FBCODE_CAFFE2 or C10_NODEPRECATED is defined. Differential Revision: [D67992342](https://our.internmc.facebook.com/intern/diff/D67992342/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144502 Approved by: https://github.com/malfet
This commit is contained in:
parent
9a841f9321
commit
0529908f13
|
|
@ -1,21 +1,21 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <ATen/cpu/vec/vec.h>
|
#include <ATen/cpu/vec/vec.h>
|
||||||
#include <c10/util/BFloat16.h> // For std::is_reduced_floating_point_v.
|
#include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
|
||||||
|
|
||||||
namespace at::native {
|
namespace at::native {
|
||||||
constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
|
constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
|
||||||
constexpr double kGeluKappa = 0.044715;
|
constexpr double kGeluKappa = 0.044715;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using reduced_fp_to_float_t = std::conditional_t<std::is_reduced_floating_point_v<T>, float, T>;
|
using reduced_fp_to_float_t = std::conditional_t<c10::is_reduced_floating_point_v<T>, float, T>;
|
||||||
|
|
||||||
template <typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
|
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||||
float reduced_fp_to_float(T x) {
|
float reduced_fp_to_float(T x) {
|
||||||
return float(x);
|
return float(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
|
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||||
T reduced_fp_to_float(T x) {
|
T reduced_fp_to_float(T x) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
@ -29,7 +29,7 @@ T scalar_gelu_approximated_with_tanh(T x) {
|
||||||
return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner));
|
return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
|
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||||
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
|
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
|
||||||
const vec::Vectorized<T> kPointFiveVec(T(0.5));
|
const vec::Vectorized<T> kPointFiveVec(T(0.5));
|
||||||
const vec::Vectorized<T> kOneVec(T(1));
|
const vec::Vectorized<T> kOneVec(T(1));
|
||||||
|
|
@ -40,7 +40,7 @@ vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x)
|
||||||
return kPointFiveVec * x * (kOneVec + inner_vec.tanh());
|
return kPointFiveVec * x * (kOneVec + inner_vec.tanh());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
|
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||||
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
|
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
|
||||||
auto [x0, x1] = at::vec::convert_to_float<T>(x);
|
auto [x0, x1] = at::vec::convert_to_float<T>(x);
|
||||||
return at::vec::convert_from_float<T>(
|
return at::vec::convert_from_float<T>(
|
||||||
|
|
@ -56,7 +56,7 @@ T scalar_gelu(T x) {
|
||||||
return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha));
|
return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
|
template<typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||||
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
||||||
const vec::Vectorized<T> kAlphaVec(T(M_SQRT1_2));
|
const vec::Vectorized<T> kAlphaVec(T(M_SQRT1_2));
|
||||||
const vec::Vectorized<T> kOneVec(T(1));
|
const vec::Vectorized<T> kOneVec(T(1));
|
||||||
|
|
@ -64,7 +64,7 @@ vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
||||||
return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf());
|
return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
|
template<typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||||
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
||||||
auto [x0, x1] = at::vec::convert_to_float<T>(x);
|
auto [x0, x1] = at::vec::convert_to_float<T>(x);
|
||||||
return at::vec::convert_from_float<T>(vectorized_gelu(x0), vectorized_gelu(x1));
|
return at::vec::convert_from_float<T>(vectorized_gelu(x0), vectorized_gelu(x1));
|
||||||
|
|
|
||||||
|
|
@ -272,7 +272,7 @@ std::ostream& operator<<(std::ostream& stream, const CheckWithinDomains<T>& dmn)
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool check_both_nan([[maybe_unused]] T x, [[maybe_unused]] T y) {
|
bool check_both_nan([[maybe_unused]] T x, [[maybe_unused]] T y) {
|
||||||
if constexpr (std::is_floating_point_v<T> || std::is_reduced_floating_point_v<T>) {
|
if constexpr (std::is_floating_point_v<T> || c10::is_reduced_floating_point_v<T>) {
|
||||||
return std::isnan(x) && std::isnan(y);
|
return std::isnan(x) && std::isnan(y);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -568,7 +568,7 @@ private:
|
||||||
uint64_t seed;
|
uint64_t seed;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, bool is_floating_point = std::is_floating_point_v<T> || std::is_reduced_floating_point_v<T>, bool is_complex = is_complex<T>::value>
|
template <typename T, bool is_floating_point = std::is_floating_point_v<T> || c10::is_reduced_floating_point_v<T>, bool is_complex = is_complex<T>::value>
|
||||||
struct ValueGen
|
struct ValueGen
|
||||||
{
|
{
|
||||||
std::uniform_int_distribution<int64_t> dis;
|
std::uniform_int_distribution<int64_t> dis;
|
||||||
|
|
@ -591,7 +591,7 @@ struct ValueGen
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using reduced_fp_to_float_t = std::conditional_t<std::is_reduced_floating_point_v<T>, float, T>;
|
using reduced_fp_to_float_t = std::conditional_t<c10::is_reduced_floating_point_v<T>, float, T>;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ValueGen<T, true, false>
|
struct ValueGen<T, true, false>
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,7 @@ C10_CLANG_DIAGNOSTIC_PUSH()
|
||||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace std {
|
namespace c10 {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct is_reduced_floating_point
|
struct is_reduced_floating_point
|
||||||
: std::integral_constant<
|
: std::integral_constant<
|
||||||
|
|
@ -19,193 +18,201 @@ struct is_reduced_floating_point
|
||||||
template <typename T>
|
template <typename T>
|
||||||
constexpr bool is_reduced_floating_point_v =
|
constexpr bool is_reduced_floating_point_v =
|
||||||
is_reduced_floating_point<T>::value;
|
is_reduced_floating_point<T>::value;
|
||||||
|
} // namespace c10
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
|
||||||
|
#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
|
||||||
|
using c10::is_reduced_floating_point;
|
||||||
|
using c10::is_reduced_floating_point_v;
|
||||||
|
#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T acos(T a) {
|
inline T acos(T a) {
|
||||||
return std::acos(float(a));
|
return std::acos(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T asin(T a) {
|
inline T asin(T a) {
|
||||||
return std::asin(float(a));
|
return std::asin(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T atan(T a) {
|
inline T atan(T a) {
|
||||||
return std::atan(float(a));
|
return std::atan(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T atanh(T a) {
|
inline T atanh(T a) {
|
||||||
return std::atanh(float(a));
|
return std::atanh(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T erf(T a) {
|
inline T erf(T a) {
|
||||||
return std::erf(float(a));
|
return std::erf(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T erfc(T a) {
|
inline T erfc(T a) {
|
||||||
return std::erfc(float(a));
|
return std::erfc(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T exp(T a) {
|
inline T exp(T a) {
|
||||||
return std::exp(float(a));
|
return std::exp(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T expm1(T a) {
|
inline T expm1(T a) {
|
||||||
return std::expm1(float(a));
|
return std::expm1(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline bool isfinite(T a) {
|
inline bool isfinite(T a) {
|
||||||
return std::isfinite(float(a));
|
return std::isfinite(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T log(T a) {
|
inline T log(T a) {
|
||||||
return std::log(float(a));
|
return std::log(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T log10(T a) {
|
inline T log10(T a) {
|
||||||
return std::log10(float(a));
|
return std::log10(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T log1p(T a) {
|
inline T log1p(T a) {
|
||||||
return std::log1p(float(a));
|
return std::log1p(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T log2(T a) {
|
inline T log2(T a) {
|
||||||
return std::log2(float(a));
|
return std::log2(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T ceil(T a) {
|
inline T ceil(T a) {
|
||||||
return std::ceil(float(a));
|
return std::ceil(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T cos(T a) {
|
inline T cos(T a) {
|
||||||
return std::cos(float(a));
|
return std::cos(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T floor(T a) {
|
inline T floor(T a) {
|
||||||
return std::floor(float(a));
|
return std::floor(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T nearbyint(T a) {
|
inline T nearbyint(T a) {
|
||||||
return std::nearbyint(float(a));
|
return std::nearbyint(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T sin(T a) {
|
inline T sin(T a) {
|
||||||
return std::sin(float(a));
|
return std::sin(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T tan(T a) {
|
inline T tan(T a) {
|
||||||
return std::tan(float(a));
|
return std::tan(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T sinh(T a) {
|
inline T sinh(T a) {
|
||||||
return std::sinh(float(a));
|
return std::sinh(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T cosh(T a) {
|
inline T cosh(T a) {
|
||||||
return std::cosh(float(a));
|
return std::cosh(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T tanh(T a) {
|
inline T tanh(T a) {
|
||||||
return std::tanh(float(a));
|
return std::tanh(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T trunc(T a) {
|
inline T trunc(T a) {
|
||||||
return std::trunc(float(a));
|
return std::trunc(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T lgamma(T a) {
|
inline T lgamma(T a) {
|
||||||
return std::lgamma(float(a));
|
return std::lgamma(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T sqrt(T a) {
|
inline T sqrt(T a) {
|
||||||
return std::sqrt(float(a));
|
return std::sqrt(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T rsqrt(T a) {
|
inline T rsqrt(T a) {
|
||||||
return 1.0 / std::sqrt(float(a));
|
return 1.0 / std::sqrt(float(a));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T abs(T a) {
|
inline T abs(T a) {
|
||||||
return std::abs(float(a));
|
return std::abs(float(a));
|
||||||
}
|
}
|
||||||
#if defined(_MSC_VER) && defined(__CUDACC__)
|
#if defined(_MSC_VER) && defined(__CUDACC__)
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T pow(T a, double b) {
|
inline T pow(T a, double b) {
|
||||||
return std::pow(float(a), float(b));
|
return std::pow(float(a), float(b));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T pow(T a, double b) {
|
inline T pow(T a, double b) {
|
||||||
return std::pow(float(a), b);
|
return std::pow(float(a), b);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T pow(T a, T b) {
|
inline T pow(T a, T b) {
|
||||||
return std::pow(float(a), float(b));
|
return std::pow(float(a), float(b));
|
||||||
}
|
}
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
inline T fmod(T a, T b) {
|
inline T fmod(T a, T b) {
|
||||||
return std::fmod(float(a), float(b));
|
return std::fmod(float(a), float(b));
|
||||||
}
|
}
|
||||||
|
|
@ -238,7 +245,7 @@ inline T fmod(T a, T b) {
|
||||||
*/
|
*/
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
||||||
C10_HOST_DEVICE inline T nextafter(T from, T to) {
|
C10_HOST_DEVICE inline T nextafter(T from, T to) {
|
||||||
// Reference:
|
// Reference:
|
||||||
// https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
|
// https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scal
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t,
|
template <typename scalar_t,
|
||||||
typename std::enable_if_t<std::is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) {
|
static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) {
|
||||||
return ptr2;
|
return ptr2;
|
||||||
}
|
}
|
||||||
|
|
@ -320,7 +320,7 @@ extern "C"
|
||||||
|
|
||||||
// dtypes of kernel and internal buffers
|
// dtypes of kernel and internal buffers
|
||||||
using scalar_t = {{kernel.dtype(query)}};
|
using scalar_t = {{kernel.dtype(query)}};
|
||||||
constexpr bool is_reduced_type = std::is_reduced_floating_point_v<scalar_t>;
|
constexpr bool is_reduced_type = c10::is_reduced_floating_point_v<scalar_t>;
|
||||||
using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
|
using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
|
||||||
using Vec = at::vec::Vectorized<accum_t>;
|
using Vec = at::vec::Vectorized<accum_t>;
|
||||||
accum_t scaling_factor = {{scale}};
|
accum_t scaling_factor = {{scale}};
|
||||||
|
|
|
||||||
|
|
@ -587,13 +587,13 @@ template <> struct AsIntegerType<double> { typedef uint64_t type; };
|
||||||
template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };
|
template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
typename std::enable_if_t<!std::is_reduced_floating_point_v<T>, T>
|
typename std::enable_if_t<!c10::is_reduced_floating_point_v<T>, T>
|
||||||
inline fetch_value(volatile T *addr) {
|
inline fetch_value(volatile T *addr) {
|
||||||
return *addr;
|
return *addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
typename std::enable_if_t<std::is_reduced_floating_point_v<T>, T>
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, T>
|
||||||
inline fetch_value(volatile T *addr) {
|
inline fetch_value(volatile T *addr) {
|
||||||
return T(addr->x, T::from_bits());
|
return T(addr->x, T::from_bits());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user