mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[MPS] Add support for entr() in eager. (#147948)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147948 Approved by: https://github.com/malfet
This commit is contained in:
parent
eb08ada5d3
commit
683e083e8d
|
|
@ -25,13 +25,23 @@ void kernel spherical_bessel_j0(
|
|||
c10::metal::spherical_bessel_j0(static_cast<T0>(input[index]));
|
||||
}
|
||||
|
||||
#define REGISTER_I0_I1(DTI, DTO) \
|
||||
template [[host_name("i0_" #DTO "_" #DTI)]] void kernel i0<DTO, DTI>( \
|
||||
device DTO*, constant DTI*, uint); \
|
||||
template [[host_name("i1_" #DTO "_" #DTI)]] void kernel i1<DTO, DTI>( \
|
||||
device DTO*, constant DTI*, uint); \
|
||||
template [[host_name("spherical_bessel_j0_" #DTO "_" #DTI)]] void kernel \
|
||||
spherical_bessel_j0<DTO, DTI>(device DTO*, constant DTI*, uint);
|
||||
template <typename T0, typename T1>
|
||||
void kernel entr(
|
||||
device T0* output,
|
||||
constant T1* input,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
output[index] = c10::metal::entr(static_cast<T0>(input[index]));
|
||||
}
|
||||
|
||||
#define REGISTER_I0_I1(DTI, DTO) \
|
||||
template [[host_name("i0_" #DTO "_" #DTI)]] void kernel i0<DTO, DTI>( \
|
||||
device DTO*, constant DTI*, uint); \
|
||||
template [[host_name("i1_" #DTO "_" #DTI)]] void kernel i1<DTO, DTI>( \
|
||||
device DTO*, constant DTI*, uint); \
|
||||
template [[host_name("spherical_bessel_j0_" #DTO "_" #DTI)]] void kernel \
|
||||
spherical_bessel_j0<DTO, DTI>(device DTO*, constant DTI*, uint); \
|
||||
template [[host_name("entr_" #DTO "_" #DTI)]] void kernel entr<DTO, DTI>( \
|
||||
device DTO*, constant DTI*, uint);
|
||||
|
||||
REGISTER_I0_I1(float, float);
|
||||
REGISTER_I0_I1(bool, float);
|
||||
|
|
|
|||
|
|
@ -24,7 +24,12 @@ static void spherical_bessel_j0_kernel_mps(TensorIteratorBase& iter) {
|
|||
lib.exec_unary_kernel(iter, "spherical_bessel_j0");
|
||||
}
|
||||
|
||||
static void entr_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "entr");
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(i0_stub, &i0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_mps)
|
||||
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)
|
||||
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_mps)
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -13230,7 +13230,7 @@
|
|||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: special_entr_out
|
||||
CPU, CUDA, MPS: special_entr_out
|
||||
tags: pointwise
|
||||
|
||||
- func: special_ndtri(Tensor self) -> Tensor
|
||||
|
|
|
|||
|
|
@ -531,20 +531,20 @@ inline float xlog1py(T x, T y) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
inline float entr(T a) {
|
||||
inline T entr(T a) {
|
||||
if (a != a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
if (a > 0) {
|
||||
return -a * ::metal::log(a);
|
||||
return static_cast<T>(-a * ::metal::log(a));
|
||||
}
|
||||
|
||||
if (a == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return -INFINITY;
|
||||
return static_cast<T>(-INFINITY);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
|
|
|||
|
|
@ -325,6 +325,7 @@ def mps_ops_modifier(ops):
|
|||
'sinc',
|
||||
'slice',
|
||||
'special.spherical_bessel_j0',
|
||||
'special.entr',
|
||||
'special.xlog1py',
|
||||
'special.zeta',
|
||||
'split',
|
||||
|
|
@ -649,7 +650,6 @@ def mps_ops_modifier(ops):
|
|||
'special.bessel_y1': None,
|
||||
'special.chebyshev_polynomial_t': None,
|
||||
'special.chebyshev_polynomial_u': None,
|
||||
'special.entr': None,
|
||||
'special.erfcx': None,
|
||||
'special.hermite_polynomial_h': None,
|
||||
'special.hermite_polynomial_he': None,
|
||||
|
|
@ -716,6 +716,9 @@ def mps_ops_modifier(ops):
|
|||
'special.xlog1py': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
|
||||
# entr does not support boolean types
|
||||
'special.entr': [torch.bool],
|
||||
|
||||
# GEMM on MPS is not supported for integral types
|
||||
'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user