diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 92496e63e1a..2fd2dbe90bb 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -148,41 +148,30 @@ __device__ Eigen::half impl_rsqrt(Eigen::half x) { template __device__ std::complex impl_sqrt(std::complex x) { - T re = x.real(), im = x.imag(); - T mod_x = sqrt(re * re + im * im); - const T root2 = 0.7071067811865475; + T a = x.real(); + T b = x.imag(); + T r = impl_sqrt(norm(x)); + // returns sqrt(0.5 * (r - v)) where v may be close to r. + auto helper = [&](const T& v) { + T diff = r - v; + if (diff < T(1e-5) * r) { + // |a| >> |b|, use rsqrt(1+x) ~= 1 + x/2. + return T(0.5) * fabs(b) * impl_rsqrt(r); + } + return sqrt(T(0.5) * diff); + }; // We pick the root with the same sign of the imaginary component as // the input. - T root[2] = {T(sqrt(mod_x + re) * root2), - T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))}; + T result[2] = {helper(-a), + [&](const T& v) { return b >= 0 ? v : -v; }(helper(a))}; // hcc/clang is really weird with its support of complex in device code; // for some reason it does not permit a 2-argument constructor - return *(reinterpret_cast*>(&root)); -} - -template -__device__ T rsqrt_helper(T x) { - return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x; + return reinterpret_cast&>(result); } template __device__ std::complex impl_rsqrt(std::complex x) { - T re = x.real(), im = x.imag(); - T r = rsqrt(re * re + im * im); - T ar2 = re * r * r; - const T root2 = 0.7071067811865475; - T root[2]; - // With float, calculating 1+re*r and 1-re*r may result in excessive errors - // due to subtraction of two close values. We have to get fancy - root[0] = sqrt(r * ((std::is_same::value && re * r < -0.98) - ? rsqrt_helper(im * im * r * r) - : 1 + re * r)) * - root2; - root[1] = sqrt(r * ((std::is_same::value && re * r > 0.98) - ? rsqrt_helper(im * im * r * r) - : 1 - re * r)) * - root2 * (im >= 0 ? -1. : 1.); - return *(reinterpret_cast*>(&root)); + return conj(impl_sqrt(x)) * impl_rsqrt(norm(x)); } template