[MPS] Add modified_bessel_k0 support to eager. (#149563)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149563
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-03-19 23:10:55 +00:00 committed by PyTorch MergeBot
parent bc86b6c55a
commit 88c2fe533f
5 changed files with 74 additions and 2 deletions

View File

@ -7,6 +7,7 @@ DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j0_forward);
// One DEFINE_UNARY_FLOATING_FUNCTOR invocation per unary special function;
// modified_bessel_k0_forward is listed alongside the other Bessel variants.
// NOTE(review): the macro is defined outside this view; presumably each line
// generates a `<name>_functor` wrapper around the same-named function — confirm.
DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i0_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i1_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_k0_forward);
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
@ -51,6 +52,7 @@ struct bessel_y1_forward_functor {
REGISTER_UNARY_OP(bessel_j1_forward, DTI, DTO); \
REGISTER_UNARY_OP(modified_bessel_i0_forward, DTI, DTO); \
REGISTER_UNARY_OP(modified_bessel_i1_forward, DTI, DTO); \
REGISTER_UNARY_OP(modified_bessel_k0_forward, DTI, DTO); \
REGISTER_UNARY_OP(bessel_y0_forward, DTI, DTO); \
REGISTER_UNARY_OP(bessel_y1_forward, DTI, DTO); \
REGISTER_UNARY_OP(i0, DTI, DTO); \

View File

@ -52,6 +52,10 @@ static void modified_bessel_i1_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "modified_bessel_i1_forward");
}
// Runs the unary kernel named "modified_bessel_k0_forward" over `iter`
// through `lib.exec_unary_kernel`.
static void modified_bessel_k0_kernel_mps(TensorIteratorBase& iter) {
  constexpr auto kernel_name = "modified_bessel_k0_forward";
  lib.exec_unary_kernel(iter, kernel_name);
}
// Runs the unary kernel named "bessel_y0_forward" over `iter` through
// `lib.exec_unary_kernel`.
static void bessel_y0_kernel_mps(TensorIteratorBase& iter) {
  constexpr auto kernel_name = "bessel_y0_forward";
  lib.exec_unary_kernel(iter, kernel_name);
}
@ -68,6 +72,7 @@ REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
// Pairs each special-function dispatch stub with the address of its
// `*_kernel_mps` function; special_modified_bessel_k0_stub is paired with
// modified_bessel_k0_kernel_mps.
// NOTE(review): REGISTER_DISPATCH is defined outside this view — confirm its
// exact registration semantics.
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_mps)
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_mps)
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &modified_bessel_k0_kernel_mps)
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_mps)
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)

View File

@ -15470,7 +15470,7 @@
- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: special_modified_bessel_k0_out
CPU, CUDA, MPS: special_modified_bessel_k0_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True

View File

@ -1254,5 +1254,71 @@ inline float modified_bessel_i1_forward(T x) {
::metal::precise::sqrt(::metal::fabs(x));
} // modified_bessel_i1_forward(T x)
template <typename T>
inline float modified_bessel_k0_forward(T x) {
constexpr float A[] = {
+1.37446543561352307156e-16,
+4.25981614279661018399e-14,
+1.03496952576338420167e-11,
+1.90451637722020886025e-09,
+2.53479107902614945675e-07,
+2.28621210311945178607e-05,
+1.26461541144692592338e-03,
+3.59799365153615016266e-02,
+3.44289899924628486886e-01,
-5.35327393233902768720e-01,
};
constexpr float B[] = {
+5.30043377268626276149e-18, -1.64758043015242134646e-17,
+5.21039150503902756861e-17, -1.67823109680541210385e-16,
+5.51205597852431940784e-16, -1.84859337734377901440e-15,
+6.34007647740507060557e-15, -2.22751332699166985548e-14,
+8.03289077536357521100e-14, -2.98009692317273043925e-13,
+1.14034058820847496303e-12, -4.51459788337394416547e-12,
+1.85594911495471785253e-11, -7.95748924447710747776e-11,
+3.57739728140030116597e-10, -1.69753450938905987466e-09,
+8.57403401741422608519e-09, -4.66048989768794782956e-08,
+2.76681363944501510342e-07, -1.83175552271911948767e-06,
+1.39498137188764993662e-05, -1.28495495816278026384e-04,
+1.56988388573005337491e-03, -3.14481013119645005427e-02,
+2.44030308206595545468e+00,
};
if (x == 0.0) {
return INFINITY;
}
if (x < 0.0) {
return NAN;
}
float p;
float q = 0.0;
if (x <= 2.0) {
float a = A[0];
for (uint8_t index = 1; index < 10; index++) {
p = q;
q = a;
a = (x * x - 2.0) * q - p + A[index];
}
return 0.5 * (a - p) -
::metal::log(0.5 * x) * modified_bessel_i0_forward(x);
}
float b = B[0];
for (uint8_t index = 1; index < 25; index++) {
p = q;
q = b;
b = (8.0 / x - 2.0) * q - p + B[index];
}
return ::metal::exp(-x) * (0.5 * (b - p)) / ::metal::sqrt(x);
} // modified_bessel_k0_forward(T x)
} // namespace metal
} // namespace c10

View File

@ -663,7 +663,6 @@ def mps_ops_modifier(ops):
'special.hermite_polynomial_he': None,
'special.laguerre_polynomial_l': None,
'special.log_ndtr': None,
'special.modified_bessel_k0': None,
'special.modified_bessel_k1': None,
'special.ndtri': None,
'special.scaled_modified_bessel_k0': None,