[MPS] Fix ICE for entr bool instantiation on M1/M2 (#152204)

By instantiating the bool flavor of `entr` explicitly; otherwise, attempts to run something like
```
% python3 -c "import torch; print(torch.special.entr(torch.testing.make_tensor(10, dtype=torch.bool, device='mps')))"
```
will fail with
```
Failed to created pipeline state object, error: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error"
```

Similar in spirit to https://github.com/pytorch/pytorch/pull/149123
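
As a quick sanity check (not part of this PR), the sketch below illustrates the behavior the fix enables: with the bool flavor instantiated, `special.entr` on a bool MPS tensor should run and agree with the CPU result instead of tripping the AGX compiler error quoted above.
```
import torch

# Minimal sanity-check sketch (not part of the PR): after the fix, entr on a
# bool MPS tensor should compile and match the CPU reference instead of
# failing with the "Compiler encountered an internal error" message above.
if torch.backends.mps.is_available():
    x = torch.testing.make_tensor(10, dtype=torch.bool, device="mps")
    mps_out = torch.special.entr(x)        # previously triggered the ICE
    cpu_out = torch.special.entr(x.cpu())  # CPU reference path
    torch.testing.assert_close(mps_out.cpu(), cpu_out)
```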
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152204
Approved by: https://github.com/dcci
Author: Nikita Shulga, 2025-04-25 11:25:39 -07:00 (committed by PyTorch MergeBot)
parent d7eb3a492c
commit 56190d2577
2 changed files with 15 additions and 4 deletions


```
@@ -16,10 +16,9 @@ DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
 DEFINE_UNARY_FLOATING_FUNCTOR(i1);
 DEFINE_UNARY_FLOATING_FUNCTOR(i1e);
 DEFINE_UNARY_FLOATING_FUNCTOR(spherical_bessel_j0);
-DEFINE_UNARY_FLOATING_FUNCTOR(entr);
 // TODO: Replaceme with DEFINE_UNARY_FLOATING_FUNCTOR
-// But for some reason instantinating bessel_y[01] on M1/M2 results in
+// But for some reason instantinating bessel_y[01] and entr on M1/M2 results in
 // Failed to created pipeline state object, error: Error Domain=AGXMetalG14X
 // Code=3 "Compiler encountered an internal error"
 struct bessel_y0_forward_functor {
@@ -50,6 +49,20 @@ struct bessel_y1_forward_functor {
   }
 };
+struct entr_functor {
+  template <typename T>
+  inline enable_if_t<is_floating_point_v<T>, T> operator()(const T x) {
+    return static_cast<T>(entr(x));
+  }
+  template <typename T>
+  inline enable_if_t<is_integral_v<T>, float> operator()(const T x) {
+    return entr(static_cast<float>(x));
+  }
+  inline float operator()(const bool x) {
+    return x ? -0.0 : 0.0;
+  }
+};
 #define REGISTER_SPECIAL(DTI, DTO) \
   REGISTER_UNARY_OP(bessel_j0_forward, DTI, DTO); \
   REGISTER_UNARY_OP(bessel_j1_forward, DTI, DTO); \
```
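
Why the bool overload above can hard-code its return values: entr(x) = -x * ln(x) for x > 0 and 0 at x == 0, so a boolean input can only produce entr(true) = -1 * ln(1) = -0.0 and entr(false) = 0.0, which is exactly what `x ? -0.0 : 0.0` yields. A small illustrative check (not part of the diff) against a plain-Python reference:
```
import math
import torch

# entr(x) = -x * ln(x) for x > 0 and 0 at x == 0, so a bool input can only
# map to -0.0 (for True, since -1 * ln(1) == -0.0) or 0.0 (for False).
def entr_bool_ref(x: bool) -> float:
    return -1.0 * math.log(1.0) if x else 0.0  # -0.0 for True, 0.0 for False

expected = torch.tensor([entr_bool_ref(True), entr_bool_ref(False)])
actual = torch.special.entr(torch.tensor([True, False]))  # CPU path as reference
torch.testing.assert_close(actual, expected)
```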


```
@@ -550,8 +550,6 @@ if torch.backends.mps.is_available():
         torch.uint8,
         torch.int8,
     ],
-    # entr does not support boolean types
-    "special.entr": [torch.bool],
     # GEMM on MPS is not supported for integral types
     "nn.functional.linear": [
         torch.int16,
```