mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Add modified_bessel_k0 support to eager. (#149563)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149563 Approved by: https://github.com/malfet
This commit is contained in:
parent
bc86b6c55a
commit
88c2fe533f
|
|
@ -7,6 +7,7 @@ DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j0_forward);
|
|||
DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i0_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i1_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_k0_forward);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
|
||||
|
|
@ -51,6 +52,7 @@ struct bessel_y1_forward_functor {
|
|||
REGISTER_UNARY_OP(bessel_j1_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(modified_bessel_i0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(modified_bessel_i1_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(modified_bessel_k0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(bessel_y0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(bessel_y1_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(i0, DTI, DTO); \
|
||||
|
|
|
|||
|
|
@ -52,6 +52,10 @@ static void modified_bessel_i1_kernel_mps(TensorIteratorBase& iter) {
|
|||
lib.exec_unary_kernel(iter, "modified_bessel_i1_forward");
|
||||
}
|
||||
|
||||
static void modified_bessel_k0_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "modified_bessel_k0_forward");
|
||||
}
|
||||
|
||||
static void bessel_y0_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "bessel_y0_forward");
|
||||
}
|
||||
|
|
@ -68,6 +72,7 @@ REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
|
|||
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &modified_bessel_k0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)
|
||||
|
|
|
|||
|
|
@ -15470,7 +15470,7 @@
|
|||
|
||||
- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: special_modified_bessel_k0_out
|
||||
CPU, CUDA, MPS: special_modified_bessel_k0_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
|
|
|
|||
|
|
@ -1254,5 +1254,71 @@ inline float modified_bessel_i1_forward(T x) {
|
|||
::metal::precise::sqrt(::metal::fabs(x));
|
||||
} // modified_bessel_i1_forward(T x)
|
||||
|
||||
template <typename T>
|
||||
inline float modified_bessel_k0_forward(T x) {
|
||||
constexpr float A[] = {
|
||||
+1.37446543561352307156e-16,
|
||||
+4.25981614279661018399e-14,
|
||||
+1.03496952576338420167e-11,
|
||||
+1.90451637722020886025e-09,
|
||||
+2.53479107902614945675e-07,
|
||||
+2.28621210311945178607e-05,
|
||||
+1.26461541144692592338e-03,
|
||||
+3.59799365153615016266e-02,
|
||||
+3.44289899924628486886e-01,
|
||||
-5.35327393233902768720e-01,
|
||||
};
|
||||
|
||||
constexpr float B[] = {
|
||||
+5.30043377268626276149e-18, -1.64758043015242134646e-17,
|
||||
+5.21039150503902756861e-17, -1.67823109680541210385e-16,
|
||||
+5.51205597852431940784e-16, -1.84859337734377901440e-15,
|
||||
+6.34007647740507060557e-15, -2.22751332699166985548e-14,
|
||||
+8.03289077536357521100e-14, -2.98009692317273043925e-13,
|
||||
+1.14034058820847496303e-12, -4.51459788337394416547e-12,
|
||||
+1.85594911495471785253e-11, -7.95748924447710747776e-11,
|
||||
+3.57739728140030116597e-10, -1.69753450938905987466e-09,
|
||||
+8.57403401741422608519e-09, -4.66048989768794782956e-08,
|
||||
+2.76681363944501510342e-07, -1.83175552271911948767e-06,
|
||||
+1.39498137188764993662e-05, -1.28495495816278026384e-04,
|
||||
+1.56988388573005337491e-03, -3.14481013119645005427e-02,
|
||||
+2.44030308206595545468e+00,
|
||||
};
|
||||
|
||||
if (x == 0.0) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
if (x < 0.0) {
|
||||
return NAN;
|
||||
}
|
||||
|
||||
float p;
|
||||
float q = 0.0;
|
||||
|
||||
if (x <= 2.0) {
|
||||
float a = A[0];
|
||||
|
||||
for (uint8_t index = 1; index < 10; index++) {
|
||||
p = q;
|
||||
q = a;
|
||||
a = (x * x - 2.0) * q - p + A[index];
|
||||
}
|
||||
|
||||
return 0.5 * (a - p) -
|
||||
::metal::log(0.5 * x) * modified_bessel_i0_forward(x);
|
||||
}
|
||||
|
||||
float b = B[0];
|
||||
|
||||
for (uint8_t index = 1; index < 25; index++) {
|
||||
p = q;
|
||||
q = b;
|
||||
b = (8.0 / x - 2.0) * q - p + B[index];
|
||||
}
|
||||
|
||||
return ::metal::exp(-x) * (0.5 * (b - p)) / ::metal::sqrt(x);
|
||||
} // modified_bessel_k0_forward(T x)
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -663,7 +663,6 @@ def mps_ops_modifier(ops):
|
|||
'special.hermite_polynomial_he': None,
|
||||
'special.laguerre_polynomial_l': None,
|
||||
'special.log_ndtr': None,
|
||||
'special.modified_bessel_k0': None,
|
||||
'special.modified_bessel_k1': None,
|
||||
'special.ndtri': None,
|
||||
'special.scaled_modified_bessel_k0': None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user