mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Fix ICE for entr bool instantiation on M1/M2 (#152204)
By instantiating it implicitly, otherwise attempts to run something like ``` % python3 -c "import torch; print(torch.special.entr(torch.testing.make_tensor(10, dtype=torch.bool, device='mps')))" ``` will fail with ``` Failed to created pipeline state object, error: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" ``` Similar in spirit to https://github.com/pytorch/pytorch/pull/149123 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152204 Approved by: https://github.com/dcci
This commit is contained in:
parent
d7eb3a492c
commit
56190d2577
|
|
@ -16,10 +16,9 @@ DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
|
|||
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(i1e);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(spherical_bessel_j0);
|
||||
DEFINE_UNARY_FLOATING_FUNCTOR(entr);
|
||||
|
||||
// TODO: Replaceme with DEFINE_UNARY_FLOATING_FUNCTOR
|
||||
// But for some reason instantinating bessel_y[01] on M1/M2 results in
|
||||
// But for some reason instantinating bessel_y[01] and entr on M1/M2 results in
|
||||
// Failed to created pipeline state object, error: Error Domain=AGXMetalG14X
|
||||
// Code=3 "Compiler encountered an internal error"
|
||||
struct bessel_y0_forward_functor {
|
||||
|
|
@ -50,6 +49,20 @@ struct bessel_y1_forward_functor {
|
|||
}
|
||||
};
|
||||
|
||||
struct entr_functor {
|
||||
template <typename T>
|
||||
inline enable_if_t<is_floating_point_v<T>, T> operator()(const T x) {
|
||||
return static_cast<T>(entr(x));
|
||||
}
|
||||
template <typename T>
|
||||
inline enable_if_t<is_integral_v<T>, float> operator()(const T x) {
|
||||
return entr(static_cast<float>(x));
|
||||
}
|
||||
inline float operator()(const bool x) {
|
||||
return x ? -0.0 : 0.0;
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_SPECIAL(DTI, DTO) \
|
||||
REGISTER_UNARY_OP(bessel_j0_forward, DTI, DTO); \
|
||||
REGISTER_UNARY_OP(bessel_j1_forward, DTI, DTO); \
|
||||
|
|
|
|||
|
|
@ -550,8 +550,6 @@ if torch.backends.mps.is_available():
|
|||
torch.uint8,
|
||||
torch.int8,
|
||||
],
|
||||
# entr does not support boolean types
|
||||
"special.entr": [torch.bool],
|
||||
# GEMM on MPS is not supported for integral types
|
||||
"nn.functional.linear": [
|
||||
torch.int16,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user