diff --git a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal index 3e58137ca32..f7c83181a7e 100644 --- a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal @@ -4,12 +4,6 @@ using namespace metal; using namespace c10::metal; -template -T complex_div(T a, T b) { - auto denom = dot(b, b); - return T(dot(a, b), a.y * b.x - a.x * b.y) / denom; -} - struct exp_functor { template inline enable_if_t, T> operator()(const T x) { @@ -81,7 +75,7 @@ struct tan_functor { // tan(x+yi)=(tan(x) + itanh(y)) / (1 - i(tan(x) * tanh(y))) auto tan_x = precise::tan(x.x); auto tanh_y = precise::tanh(x.y); - return complex_div(T(tan_x, tanh_y), T(1, -1 * tan_x * tanh_y)); + return div(T(tan_x, tanh_y), T(1, -1 * tan_x * tanh_y)); } }; @@ -99,7 +93,7 @@ struct tanh_functor { // tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y)); auto tanh_x = precise::tanh(x.x); auto tan_y = precise::tan(x.y); - return complex_div(T(tanh_x, tan_y), T(1.0, tanh_x * tan_y)); + return div(T(tanh_x, tan_y), T(1.0, tanh_x * tan_y)); } }; @@ -144,7 +138,7 @@ struct log10_functor { auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y); auto real = ::precise::log(magnitude); auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x); - return complex_div(T(real, imag), T(::precise::log(10), 0)); + return div(T(real, imag), T(::precise::log(10), 0)); } inline float operator()(const bool x) { return x ? 0 : -INFINITY; @@ -166,7 +160,7 @@ struct log2_functor { auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y); auto real = ::precise::log(magnitude); auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x); - return complex_div(T(real, imag), T(::precise::log(2), 0)); + return div(T(real, imag), T(::precise::log(2), 0)); } inline float operator()(const bool x) { return x ? 0 : -INFINITY;