mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE][MPS] Use copysign for imaginary part of sqrt (#148286)
Also, it is tempting to replace `a*a + b*b` with `dot(input[index])`, but for some reason that results in slightly different output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148286
Approved by: https://github.com/dcci
ghstack dependencies: #148285
This commit is contained in:
parent
84502baaff
commit
3ca1a2564d
|
|
@@ -56,17 +56,11 @@ kernel void sqrt_complex_kernel(
|
|||
T0 b = input[index].y;
|
||||
|
||||
// modulus
|
||||
T0 r = T0(precise::sqrt(a * a + b * b));
|
||||
// real part: sqrt((r + a)/2)
|
||||
T0 real_part = T0(precise::sqrt((r + a) / 2.0));
|
||||
|
||||
// imaginary part: sign(b) * sqrt((r - a)/2)
|
||||
T0 imag_part;
|
||||
if (b >= 0) {
|
||||
imag_part = T0(precise::sqrt((r - a) / 2.0));
|
||||
} else {
|
||||
imag_part = T0(-precise::sqrt((r - a) / 2.0));
|
||||
}
|
||||
auto m = precise::sqrt(a * a + b * b);
|
||||
// real part: sqrt((m + a)/2)
|
||||
auto real_part = precise::sqrt((m + a) * .5);
|
||||
// imaginary part: sign(b) * sqrt((m - a)/2)
|
||||
auto imag_part = copysign(static_cast<T0>(precise::sqrt((m - a) * .5)), b);
|
||||
output[index] = vec2type_t<T0>(real_part, imag_part);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user