mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Improve ROCm's sqrt and rsqrt for std::complex.
This commit is contained in:
parent
f3b556a903
commit
f3bbee405a
|
|
@ -148,41 +148,30 @@ __device__ Eigen::half impl_rsqrt(Eigen::half x) {
|
|||
|
||||
template <class T>
|
||||
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T mod_x = sqrt(re * re + im * im);
|
||||
const T root2 = 0.7071067811865475;
|
||||
T a = x.real();
|
||||
T b = x.imag();
|
||||
T r = impl_sqrt(norm(x));
|
||||
// returns sqrt(0.5 * (r - v)) where v may be close to r.
|
||||
auto helper = [&](const T& v) {
|
||||
T diff = r - v;
|
||||
if (diff < T(1e-5) * r) {
|
||||
// |a| >> |b|, use rsqrt(1+x) ~= 1 + x/2.
|
||||
return T(0.5) * fabs(b) * impl_rsqrt(r);
|
||||
}
|
||||
return sqrt(T(0.5) * diff);
|
||||
};
|
||||
// We pick the root with the same sign of the imaginary component as
|
||||
// the input.
|
||||
T root[2] = {T(sqrt(mod_x + re) * root2),
|
||||
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
|
||||
T result[2] = {helper(-a),
|
||||
[&](const T& v) { return b >= 0 ? v : -v; }(helper(a))};
|
||||
// hcc/clang is really weird with its support of complex in device code;
|
||||
// for some reason it does not permit a 2-argument constructor
|
||||
return *(reinterpret_cast<std::complex<T>*>(&root));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ T rsqrt_helper(T x) {
|
||||
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
|
||||
return reinterpret_cast<const std::complex<T>&>(result);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T r = rsqrt(re * re + im * im);
|
||||
T ar2 = re * r * r;
|
||||
const T root2 = 0.7071067811865475;
|
||||
T root[2];
|
||||
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
|
||||
// due to subtraction of two close values. We have to get fancy
|
||||
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
|
||||
? rsqrt_helper(im * im * r * r)
|
||||
: 1 + re * r)) *
|
||||
root2;
|
||||
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
|
||||
? rsqrt_helper(im * im * r * r)
|
||||
: 1 - re * r)) *
|
||||
root2 * (im >= 0 ? -1. : 1.);
|
||||
return *(reinterpret_cast<std::complex<T>*>(&root));
|
||||
return conj(impl_sqrt(x)) * impl_rsqrt(norm(x));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user