[MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754
Approved by: https://github.com/malfet, https://github.com/jansel
This commit is contained in:
Davide Italiano 2025-04-20 17:44:40 +00:00 committed by PyTorch MergeBot
parent c3a7278278
commit 470132c6a1
8 changed files with 55 additions and 4 deletions

View File

@ -82,6 +82,13 @@ struct hermite_polynomial_h_functor {
}
};
struct hermite_polynomial_he_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(c10::metal::hermite_polynomial_he_forward(a, b));
}
};
struct nextafter_functor {
#if __METAL_VERSION__ < 310
template <typename U>
@ -173,6 +180,8 @@ REGISTER_BINARY_OP(chebyshev_polynomial_w, float, float);
REGISTER_BINARY_OP(chebyshev_polynomial_w, half, half);
REGISTER_BINARY_OP(hermite_polynomial_h, float, float);
REGISTER_BINARY_OP(hermite_polynomial_h, half, half);
REGISTER_BINARY_OP(hermite_polynomial_he, float, float);
REGISTER_BINARY_OP(hermite_polynomial_he, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(copysign, bfloat, bfloat);
@ -186,6 +195,7 @@ REGISTER_BINARY_OP(chebyshev_polynomial_u, bfloat, bfloat);
REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat);
REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat);
REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
#endif
// Complex binary functions

View File

@ -116,6 +116,12 @@ static void hermite_polynomial_h_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "hermite_polynomial_h");
}
static void hermite_polynomial_he_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()),
"hermite_polynomial_he_mps not implemented for non-floating types");
lib.exec_binary_kernel(iter, "hermite_polynomial_he");
}
static void polar_mps_kernel(TensorIterator& iter) {
lib.exec_binary_kernel(iter, "polar");
}
@ -135,6 +141,7 @@ REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kerne
REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_mps_kernel)
REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_mps_kernel)
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_mps_kernel)
REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_mps_kernel)
REGISTER_DISPATCH(polar_stub, &polar_mps_kernel);
REGISTER_DISPATCH(complex_stub, &complex_mps_kernel);
} // namespace at::native

View File

@ -15363,7 +15363,7 @@
- func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
dispatch:
CPU, CUDA: special_hermite_polynomial_he_out
CPU, CUDA, MPS: special_hermite_polynomial_he_out
python_module: special
structured_inherits: TensorIteratorBase
structured: True

View File

@ -1753,5 +1753,36 @@ inline float hermite_polynomial_h_forward(T x, int64_t n) {
return r;
} // hermite_polynomial_h_forward(T x, int64_t n)
template <typename T>
inline float hermite_polynomial_he_forward(T x, int64_t n) {
if (n < 0) {
return 0.0;
}
if (n == 0) {
return 1.0;
}
if (n == 1) {
return x;
}
if (n > getHermitianLimit<T>()) {
return NAN;
}
float p = 1.0;
float q = x;
float r;
for (int64_t k = 1; k < n; k++) {
r = x * q - k * p;
p = q;
q = r;
}
return r;
} // hermite_polynomial_he_forward(T x, int64_t n)
} // namespace metal
} // namespace c10

View File

@ -132,7 +132,7 @@ class MPSBasicTests(TestCase):
"chebyshev_polynomial_u",
"chebyshev_polynomial_v",
"chebyshev_polynomial_w",
"hermite_polynomial_h",
"hermite_polynomial_he",
],
)
def test_pointwise_binary_op(self, op_name):

View File

@ -12525,7 +12525,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
"erfcx",
"gammainc",
"gammaincc",
"hermite_polynomial_he",
"laguerre_polynomial_l",
"legendre_polynomial_p",
"log_ndtr",

View File

@ -656,7 +656,6 @@ def mps_ops_modifier(ops):
'sparse.mmreduce': None,
'special.airy_ai': None,
'special.erfcx': None,
'special.hermite_polynomial_he': None,
'special.laguerre_polynomial_l': None,
'special.log_ndtr': None,
'special.ndtri': None,
@ -713,6 +712,7 @@ def mps_ops_modifier(ops):
'special.chebyshev_polynomial_t': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'special.chebyshev_polynomial_u': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'special.hermite_polynomial_h': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'special.hermite_polynomial_he': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# entr does not support boolean types
'special.entr': [torch.bool],

View File

@ -477,6 +477,10 @@ class MetalOverrides(OpOverrides):
def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::hermite_polynomial_h_forward({x}, {n})"
@staticmethod
def hermite_polynomial_he(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::hermite_polynomial_he_forward({x}, {n})"
MetalOverrides._initialize_pointwise_overrides("mps")