mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS][BE] Delete complex_div (#154275)
An absolute no-op: delete `complex_div` from `UnaryKernel.metal` and use identical one from `c10/metal/utils.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/154275 Approved by: https://github.com/dcci
This commit is contained in:
parent
dec6a47996
commit
6f34d141ab
|
|
@ -4,12 +4,6 @@
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
using namespace c10::metal;
|
using namespace c10::metal;
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T complex_div(T a, T b) {
|
|
||||||
auto denom = dot(b, b);
|
|
||||||
return T(dot(a, b), a.y * b.x - a.x * b.y) / denom;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct exp_functor {
|
struct exp_functor {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline enable_if_t<is_scalar_floating_point_v<T>, T> operator()(const T x) {
|
inline enable_if_t<is_scalar_floating_point_v<T>, T> operator()(const T x) {
|
||||||
|
|
@ -81,7 +75,7 @@ struct tan_functor {
|
||||||
// tan(x+yi)=(tan(x) + itanh(y)) / (1 - i(tan(x) * tanh(y)))
|
// tan(x+yi)=(tan(x) + itanh(y)) / (1 - i(tan(x) * tanh(y)))
|
||||||
auto tan_x = precise::tan(x.x);
|
auto tan_x = precise::tan(x.x);
|
||||||
auto tanh_y = precise::tanh(x.y);
|
auto tanh_y = precise::tanh(x.y);
|
||||||
return complex_div(T(tan_x, tanh_y), T(1, -1 * tan_x * tanh_y));
|
return div(T(tan_x, tanh_y), T(1, -1 * tan_x * tanh_y));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -99,7 +93,7 @@ struct tanh_functor {
|
||||||
// tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
|
// tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
|
||||||
auto tanh_x = precise::tanh(x.x);
|
auto tanh_x = precise::tanh(x.x);
|
||||||
auto tan_y = precise::tan(x.y);
|
auto tan_y = precise::tan(x.y);
|
||||||
return complex_div(T(tanh_x, tan_y), T(1.0, tanh_x * tan_y));
|
return div(T(tanh_x, tan_y), T(1.0, tanh_x * tan_y));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -144,7 +138,7 @@ struct log10_functor {
|
||||||
auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y);
|
auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y);
|
||||||
auto real = ::precise::log(magnitude);
|
auto real = ::precise::log(magnitude);
|
||||||
auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x);
|
auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x);
|
||||||
return complex_div(T(real, imag), T(::precise::log(10), 0));
|
return div(T(real, imag), T(::precise::log(10), 0));
|
||||||
}
|
}
|
||||||
inline float operator()(const bool x) {
|
inline float operator()(const bool x) {
|
||||||
return x ? 0 : -INFINITY;
|
return x ? 0 : -INFINITY;
|
||||||
|
|
@ -166,7 +160,7 @@ struct log2_functor {
|
||||||
auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y);
|
auto magnitude = ::precise::sqrt(x.x * x.x + x.y * x.y);
|
||||||
auto real = ::precise::log(magnitude);
|
auto real = ::precise::log(magnitude);
|
||||||
auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x);
|
auto imag = (x.x == 0 && x.y == 0) ? 0 : ::precise::atan2(x.y, x.x);
|
||||||
return complex_div(T(real, imag), T(::precise::log(2), 0));
|
return div(T(real, imag), T(::precise::log(2), 0));
|
||||||
}
|
}
|
||||||
inline float operator()(const bool x) {
|
inline float operator()(const bool x) {
|
||||||
return x ? 0 : -INFINITY;
|
return x ? 0 : -INFINITY;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user