[BE][MPS] Use copysign for imaginary part of sqrt (#148286)

It is also tempting to replace `a*a + b*b` with `dot(input[index])`, but for some reason that produces slightly different output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148286
Approved by: https://github.com/dcci
ghstack dependencies: #148285
This commit is contained in:
Nikita Shulga 2025-03-02 13:57:44 -08:00 committed by PyTorch MergeBot
parent 84502baaff
commit 3ca1a2564d

View File

@@ -56,17 +56,11 @@ kernel void sqrt_complex_kernel(
T0 b = input[index].y;
// modulus
T0 r = T0(precise::sqrt(a * a + b * b));
// real part: sqrt((r + a)/2)
T0 real_part = T0(precise::sqrt((r + a) / 2.0));
// imaginary part: sign(b) * sqrt((r - a)/2)
T0 imag_part;
if (b >= 0) {
imag_part = T0(precise::sqrt((r - a) / 2.0));
} else {
imag_part = T0(-precise::sqrt((r - a) / 2.0));
}
auto m = precise::sqrt(a * a + b * b);
// real part: sqrt((m + a)/2)
auto real_part = precise::sqrt((m + a) * .5);
// imaginary part: sign(b) * sqrt((m - a)/2)
auto imag_part = copysign(static_cast<T0>(precise::sqrt((m - a) * .5)), b);
output[index] = vec2type_t<T0>(real_part, imag_part);
}