mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Add support for i1e (#149203)
Followup after https://github.com/pytorch/pytorch/pull/149174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149203 Approved by: https://github.com/dcci
This commit is contained in:
parent
f067eafabb
commit
f2221b2fce
|
|
@@ -8,6 +8,7 @@ DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
|
|||
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i1e);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(spherical_bessel_j0);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(entr);
|
||||
|
||||
|
|
@@ -51,6 +52,7 @@ struct bessel_y1_forward_functor {
|
|||
REGISTER_UNARY_OP(i0, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(i0e, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(i1, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(i1e, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(spherical_bessel_j0, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(entr, DTI, DTO)
|
||||
|
||||
|
|
|
|||
|
|
@@ -24,6 +24,10 @@ static void i1_kernel_mps(TensorIteratorBase& iter) {
|
|||
lib.exec_unary_kernel(iter, "i1");
|
||||
}
|
||||
|
||||
// MPS dispatch stub for aten::special_i1e: executes the "i1e" unary Metal
// kernel (exponentially scaled modified Bessel function of the first kind,
// order 1) elementwise over the iterator's operands.
// NOTE(review): `lib` is presumably the file-level MetalShaderLibrary — the
// scrape artifact lines ("|", "||||") that were interleaved here have been
// removed to restore compilable code.
static void i1e_kernel_mps(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "i1e");
}
|
||||
|
||||
// MPS dispatch stub for aten::special_spherical_bessel_j0: executes the
// "spherical_bessel_j0" unary Metal kernel elementwise over the iterator's
// operands. (Scrape artifact lines interleaved in the original span removed
// to restore compilable code.)
static void spherical_bessel_j0_kernel_mps(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "spherical_bessel_j0");
}
|
||||
|
|
@@ -51,6 +55,7 @@ static void bessel_y1_kernel_mps(TensorIteratorBase& iter) {
|
|||
REGISTER_DISPATCH(i0_stub, &i0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_i0e_stub, &i0e_kernel_mps)
|
||||
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)
|
||||
|
|
|
|||
|
|
@@ -13528,7 +13528,7 @@
|
|||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: special_i1e_out
|
||||
CPU, CUDA, MPS: special_i1e_out
|
||||
tags: pointwise
|
||||
|
||||
- func: special_logit(Tensor self, float? eps=None) -> Tensor
|
||||
|
|
|
|||
|
|
@@ -241,6 +241,54 @@ inline T i1(T _x) {
|
|||
return static_cast<T>(_x < T(0.) ? -out : out);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T i1e(T _x) {
|
||||
const auto x = ::metal::fabs(_x);
|
||||
if (x <= 8.0) {
|
||||
// Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8].
|
||||
// Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2.
|
||||
constexpr float coefficients[] = {
|
||||
9.38153738649577178388E-9f,
|
||||
-4.44505912879632808065E-8f,
|
||||
2.00329475355213526229E-7f,
|
||||
-8.56872026469545474066E-7f,
|
||||
3.47025130813767847674E-6f,
|
||||
-1.32731636560394358279E-5f,
|
||||
4.78156510755005422638E-5f,
|
||||
-1.61760815825896745588E-4f,
|
||||
5.12285956168575772895E-4f,
|
||||
-1.51357245063125314899E-3f,
|
||||
4.15642294431288815669E-3f,
|
||||
-1.05640848946261981558E-2f,
|
||||
2.47264490306265168283E-2f,
|
||||
-5.29459812080949914269E-2f,
|
||||
1.02643658689847095384E-1f,
|
||||
-1.76416518357834055153E-1f,
|
||||
2.52587186443633654823E-1f};
|
||||
const auto y = x / 2.0 - 2.0;
|
||||
const auto out = chbevl(y, coefficients, 17) * x;
|
||||
return static_cast<T>(_x < 0. ? -out : out);
|
||||
}
|
||||
|
||||
// Chebyshev coefficients for exp(-x) sqrt(x) i1(x)
|
||||
// in the inverted interval (8, infinity].
|
||||
// Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi).
|
||||
// TODO: what's an "inverted interval"? Open on the left
|
||||
// and closed on the right?
|
||||
constexpr float coefficients[] = {
|
||||
-3.83538038596423702205E-9f,
|
||||
-2.63146884688951950684E-8f,
|
||||
-2.51223623787020892529E-7f,
|
||||
-3.88256480887769039346E-6f,
|
||||
-1.10588938762623716291E-4f,
|
||||
-9.76109749136146840777E-3f,
|
||||
7.78576235018280120474E-1f};
|
||||
|
||||
const auto out =
|
||||
chbevl(32. / x - 2., coefficients, 7) / ::metal::precise::sqrt(x);
|
||||
return static_cast<T>(_x < 0. ? -out : out);
|
||||
}
|
||||
|
||||
// gamma, lgamma
|
||||
template <typename T>
|
||||
inline float log_gamma(const T);
|
||||
|
|
|
|||
|
|
@@ -97,7 +97,7 @@ def mps_ops_grad_modifier(ops):
|
|||
'logdet': [torch.float16, torch.float32], # missing aten::lu_solve.out
|
||||
'aminmax': [torch.float32, torch.float16],
|
||||
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'
|
||||
'special.i0e': None, # "special_i1e" not implemented
|
||||
'special.i1e': [torch.float16], # "i1e_backward" not implemented for 'Half'
|
||||
|
||||
# Correctness issues
|
||||
'atanh': [torch.float32],
|
||||
|
|
@@ -651,7 +651,6 @@ def mps_ops_modifier(ops):
|
|||
'special.erfcx': None,
|
||||
'special.hermite_polynomial_h': None,
|
||||
'special.hermite_polynomial_he': None,
|
||||
'special.i1e': None,
|
||||
'special.laguerre_polynomial_l': None,
|
||||
'special.log_ndtr': None,
|
||||
'special.modified_bessel_i0': None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user