[MPS][BE] Delete complex_div (#154275)

An absolute no-op: delete `complex_div` from `UnaryKernel.metal` and use identical one from `c10/metal/utils.h`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154275
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga 2025-05-23 13:13:31 -07:00 committed by PyTorch MergeBot
parent dec6a47996
commit 6f34d141ab

View File

@ -4,12 +4,6 @@
using namespace metal; using namespace metal;
using namespace c10::metal; using namespace c10::metal;
template <typename T>
T complex_div(T a, T b) {
auto denom = dot(b, b);
return T(dot(a, b), a.y * b.x - a.x * b.y) / denom;
}
struct exp_functor { struct exp_functor {
template <typename T> template <typename T>
inline enable_if_t<is_scalar_floating_point_v<T>, T> operator()(const T x) { inline enable_if_t<is_scalar_floating_point_v<T>, T> operator()(const T x) {
@ -81,7 +75,7 @@ struct tan_functor {
// tan(x+yi)=(tan(x) + itanh(y)) / (1 - i(tan(x) * tanh(y))) // tan(x+yi)=(tan(x) + itanh(y)) / (1 - i(tan(x) * tanh(y)))
auto tan_x = precise::tan(x.x); auto tan_x = precise::tan(x.x);
auto tanh_y = precise::tanh(x.y); auto tanh_y = precise::tanh(x.y);
return complex_div(T(tan_x, tanh_y), T(1, -1 * tan_x * tanh_y)); return div(T(tan_x, tanh_y), T(1, -1 * tan_x * tanh_y));
} }
}; };
@ -99,7 +93,7 @@ struct tanh_functor {
// tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y)); // tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
auto tanh_x = precise::tanh(x.x); auto tanh_x = precise::tanh(x.x);
auto tan_y = precise::tan(x.y); auto tan_y = precise::tan(x.y);
return complex_div(T(tanh_x, tan_y), T(1.0, tanh_x * tan_y)); return div(T(tanh_x, tan_y), T(1.0, tanh_x * tan_y));
} }
}; };
@ -144,7 +138,7 @@ struct log10_functor {
auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y); auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y);
auto real = ::precise::log(magnitude); auto real = ::precise::log(magnitude);
auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x); auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x);
return complex_div(T(real, imag), T(::precise::log(10), 0)); return div(T(real, imag), T(::precise::log(10), 0));
} }
inline float operator()(const bool x) { inline float operator()(const bool x) {
return x ? 0 : -INFINITY; return x ? 0 : -INFINITY;
@ -166,7 +160,7 @@ struct log2_functor {
auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y); auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y);
auto real = ::precise::log(magnitude); auto real = ::precise::log(magnitude);
auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x); auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x);
return complex_div(T(real, imag), T(::precise::log(2), 0)); return div(T(real, imag), T(::precise::log(2), 0));
} }
inline float operator()(const bool x) { inline float operator()(const bool x) {
return x ? 0 : -INFINITY; return x ? 0 : -INFINITY;