[MPS] Add support for i1e (#149203)

Follow-up to https://github.com/pytorch/pytorch/pull/149174
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149203
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga 2025-03-14 09:29:23 -07:00 committed by PyTorch MergeBot
parent f067eafabb
commit f2221b2fce
5 changed files with 57 additions and 3 deletions

View File

@ -8,6 +8,7 @@ DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
// One DEFINE_UNARY_FLOATING_FUNCTOR invocation per special-math function.
// The macro is defined outside this excerpt; presumably it emits a
// `<name>_functor` wrapper around the function of the same name
// (TODO confirm against the macro definition).
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
DEFINE_UNARY_FLOATING_FUNCTOR(i1e);
DEFINE_UNARY_FLOATING_FUNCTOR(spherical_bessel_j0);
DEFINE_UNARY_FLOATING_FUNCTOR(entr);
@ -51,6 +52,7 @@ struct bessel_y1_forward_functor {
REGISTER_UNARY_OP(i0, DTI, DTO); \
REGISTER_UNARY_OP(i0e, DTI, DTO); \
REGISTER_UNARY_OP(i1, DTI, DTO); \
REGISTER_UNARY_OP(i1e, DTI, DTO); \
REGISTER_UNARY_OP(spherical_bessel_j0, DTI, DTO); \
REGISTER_UNARY_OP(entr, DTI, DTO)

View File

@ -24,6 +24,10 @@ static void i1_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "i1");
}
// MPS entry point for i1e: forwards `iter` to `lib.exec_unary_kernel`
// together with the name string "i1e".
static void i1e_kernel_mps(TensorIteratorBase& iter) {
  constexpr auto kernel_name = "i1e";
  lib.exec_unary_kernel(iter, kernel_name);
}
// MPS entry point for spherical_bessel_j0: forwards `iter` to
// `lib.exec_unary_kernel` together with the name string "spherical_bessel_j0".
static void spherical_bessel_j0_kernel_mps(TensorIteratorBase& iter) {
  constexpr auto kernel_name = "spherical_bessel_j0";
  lib.exec_unary_kernel(iter, kernel_name);
}
@ -51,6 +55,7 @@ static void bessel_y1_kernel_mps(TensorIteratorBase& iter) {
// Dispatch registrations: each line pairs a dispatch stub (first argument)
// with the address of the MPS kernel function that serves it (second argument).
REGISTER_DISPATCH(i0_stub, &i0_kernel_mps)
REGISTER_DISPATCH(special_i0e_stub, &i0e_kernel_mps)
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_mps)
REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_mps)
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)

View File

@ -13528,7 +13528,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: special_i1e_out
CPU, CUDA, MPS: special_i1e_out
tags: pointwise
- func: special_logit(Tensor self, float? eps=None) -> Tensor

View File

@ -241,6 +241,54 @@ inline T i1(T _x) {
return static_cast<T>(_x < T(0.) ? -out : out);
}
// Computes i1e(x): per the coefficient notes below, this is exp(-|x|) * i1(x).
//
// The magnitude |x| is evaluated with one of two Chebyshev series (via
// `chbevl`, which is defined outside this block), split at |x| = 8. The
// result is odd in x: the sign of `_x` is re-applied just before returning.
//
// _x: input value. The result is cast back to T on return.
template <typename T>
inline T i1e(T _x) {
  const auto x = ::metal::fabs(_x);
  if (x <= 8.0) {
    // Chebyshev coefficients (single precision) for exp(-x) i1(x) / x in the
    // interval [0,8]; the series value is multiplied by x below.
    // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2.
    constexpr float coefficients[] = {
        9.38153738649577178388E-9f,
        -4.44505912879632808065E-8f,
        2.00329475355213526229E-7f,
        -8.56872026469545474066E-7f,
        3.47025130813767847674E-6f,
        -1.32731636560394358279E-5f,
        4.78156510755005422638E-5f,
        -1.61760815825896745588E-4f,
        5.12285956168575772895E-4f,
        -1.51357245063125314899E-3f,
        4.15642294431288815669E-3f,
        -1.05640848946261981558E-2f,
        2.47264490306265168283E-2f,
        -5.29459812080949914269E-2f,
        1.02643658689847095384E-1f,
        -1.76416518357834055153E-1f,
        2.52587186443633654823E-1f};
    // Map x in [0, 8] linearly onto the series argument range [-2, 2].
    const auto y = x / 2.0 - 2.0;
    const auto out = chbevl(y, coefficients, 17) * x;
    return static_cast<T>(_x < 0. ? -out : out);
  }
  // Chebyshev coefficients for exp(-x) sqrt(x) i1(x)
  // in the inverted interval (8, infinity].
  // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi).
  // "Inverted" refers to the series argument being a function of 1/x:
  // 32/x - 2 maps x in (8, infinity] onto [-2, 2), so x = infinity
  // corresponds to the finite endpoint -2.
  constexpr float coefficients[] = {
      -3.83538038596423702205E-9f,
      -2.63146884688951950684E-8f,
      -2.51223623787020892529E-7f,
      -3.88256480887769039346E-6f,
      -1.10588938762623716291E-4f,
      -9.76109749136146840777E-3f,
      7.78576235018280120474E-1f};
  // Divide by sqrt(x) to undo the sqrt(x) factor carried by the series.
  const auto out =
      chbevl(32. / x - 2., coefficients, 7) / ::metal::precise::sqrt(x);
  return static_cast<T>(_x < 0. ? -out : out);
}
// gamma, lgamma
template <typename T>
inline float log_gamma(const T);

View File

@ -97,7 +97,7 @@ def mps_ops_grad_modifier(ops):
'logdet': [torch.float16, torch.float32], # missing aten::lu_solve.out
'aminmax': [torch.float32, torch.float16],
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'
'special.i0e': None, # "special_i1e" not implemented
'special.i1e': [torch.float16], # "i1e_backward" not implemented for 'Half'
# Correctness issues
'atanh': [torch.float32],
@ -651,7 +651,6 @@ def mps_ops_modifier(ops):
'special.erfcx': None,
'special.hermite_polynomial_h': None,
'special.hermite_polynomial_he': None,
'special.i1e': None,
'special.laguerre_polynomial_l': None,
'special.log_ndtr': None,
'special.modified_bessel_i0': None,