[MPS] Add support for entr() in eager. (#147948)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147948
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-02-26 19:55:02 +00:00 committed by PyTorch MergeBot
parent eb08ada5d3
commit 683e083e8d
5 changed files with 30 additions and 12 deletions

View File

@ -25,13 +25,23 @@ void kernel spherical_bessel_j0(
c10::metal::spherical_bessel_j0(static_cast<T0>(input[index]));
}
#define REGISTER_I0_I1(DTI, DTO) \
template [[host_name("i0_" #DTO "_" #DTI)]] void kernel i0<DTO, DTI>( \
device DTO*, constant DTI*, uint); \
template [[host_name("i1_" #DTO "_" #DTI)]] void kernel i1<DTO, DTI>( \
device DTO*, constant DTI*, uint); \
template [[host_name("spherical_bessel_j0_" #DTO "_" #DTI)]] void kernel \
spherical_bessel_j0<DTO, DTI>(device DTO*, constant DTI*, uint);
// Elementwise kernel: writes c10::metal::entr of each input element to the
// corresponding output slot. One thread handles one element; `index` is the
// flat position in the launch grid.
template <typename T0, typename T1>
void kernel entr(
    device T0* output,
    constant T1* input,
    uint index [[thread_position_in_grid]]) {
  const auto value = static_cast<T0>(input[index]);
  output[index] = c10::metal::entr(value);
}
// Instantiates the unary special-function kernels (i0, i1,
// spherical_bessel_j0, entr) for an input dtype DTI and output dtype DTO,
// exporting each specialization under a host-visible name of the form
// "<op>_<DTO>_<DTI>" so the host side can look it up by string.
// NOTE: lines below are backslash-continuations of one macro — do not
// insert comments between them.
#define REGISTER_I0_I1(DTI, DTO) \
template [[host_name("i0_" #DTO "_" #DTI)]] void kernel i0<DTO, DTI>( \
device DTO*, constant DTI*, uint); \
template [[host_name("i1_" #DTO "_" #DTI)]] void kernel i1<DTO, DTI>( \
device DTO*, constant DTI*, uint); \
template [[host_name("spherical_bessel_j0_" #DTO "_" #DTI)]] void kernel \
spherical_bessel_j0<DTO, DTI>(device DTO*, constant DTI*, uint); \
template [[host_name("entr_" #DTO "_" #DTI)]] void kernel entr<DTO, DTI>( \
device DTO*, constant DTI*, uint);
// Instantiate for the dtype pairs the MPS backend ships.
REGISTER_I0_I1(float, float);
REGISTER_I0_I1(bool, float);

View File

@ -24,7 +24,12 @@ static void spherical_bessel_j0_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "spherical_bessel_j0");
}
// Dispatches torch.special.entr on MPS by running the "entr" Metal unary
// kernel over the iterator's elements. The string must match the
// host_name registered in the Metal shader library.
static void entr_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "entr");
}
// Wire the MPS implementations into the per-op dispatch stubs so the
// dispatcher routes these special functions to the kernels above.
REGISTER_DISPATCH(i0_stub, &i0_kernel_mps)
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_mps)
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_mps)
} // namespace at::native

View File

@ -13230,7 +13230,7 @@
python_module: special
variants: function
dispatch:
CPU, CUDA: special_entr_out
CPU, CUDA, MPS: special_entr_out
tags: pointwise
- func: special_ndtri(Tensor self) -> Tensor

View File

@ -531,20 +531,20 @@ inline float xlog1py(T x, T y) {
}
// Elementwise entropy term, matching scipy.special.entr:
//   entr(a) = -a * log(a)  for a > 0
//   entr(0) = 0
//   entr(a) = -inf         for a < 0
// NaN inputs propagate unchanged. Defect fixed: the span contained both the
// pre- and post-diff lines (two conflicting signatures and duplicated return
// statements), which is not valid code; this is the reconstructed definition
// with a generic return type T instead of hard-coded float.
template <typename T>
inline T entr(T a) {
  // a != a is true only for NaN; propagate it as-is.
  if (a != a) {
    return a;
  }
  if (a > 0) {
    return static_cast<T>(-a * ::metal::log(a));
  }
  if (a == 0) {
    return 0;
  }
  // Negative input: entr is defined to be -infinity.
  return static_cast<T>(-INFINITY);
}
} // namespace metal

View File

@ -325,6 +325,7 @@ def mps_ops_modifier(ops):
'sinc',
'slice',
'special.spherical_bessel_j0',
'special.entr',
'special.xlog1py',
'special.zeta',
'split',
@ -649,7 +650,6 @@ def mps_ops_modifier(ops):
'special.bessel_y1': None,
'special.chebyshev_polynomial_t': None,
'special.chebyshev_polynomial_u': None,
'special.entr': None,
'special.erfcx': None,
'special.hermite_polynomial_h': None,
'special.hermite_polynomial_he': None,
@ -716,6 +716,9 @@ def mps_ops_modifier(ops):
'special.xlog1py': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# entr does not support boolean types
'special.entr': [torch.bool],
# GEMM on MPS is not supported for integral types
'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],