mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754 Approved by: https://github.com/malfet, https://github.com/jansel
This commit is contained in:
parent
c3a7278278
commit
470132c6a1
|
|
@ -82,6 +82,13 @@ struct hermite_polynomial_h_functor {
|
|||
}
|
||||
};
|
||||
|
||||
struct hermite_polynomial_he_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return static_cast<T>(c10::metal::hermite_polynomial_he_forward(a, b));
|
||||
}
|
||||
};
|
||||
|
||||
struct nextafter_functor {
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename U>
|
||||
|
|
@ -173,6 +180,8 @@ REGISTER_BINARY_OP(chebyshev_polynomial_w, float, float);
|
|||
REGISTER_BINARY_OP(chebyshev_polynomial_w, half, half);
|
||||
REGISTER_BINARY_OP(hermite_polynomial_h, float, float);
|
||||
REGISTER_BINARY_OP(hermite_polynomial_h, half, half);
|
||||
REGISTER_BINARY_OP(hermite_polynomial_he, float, float);
|
||||
REGISTER_BINARY_OP(hermite_polynomial_he, half, half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(copysign, bfloat, bfloat);
|
||||
|
|
@ -186,6 +195,7 @@ REGISTER_BINARY_OP(chebyshev_polynomial_u, bfloat, bfloat);
|
|||
REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat);
|
||||
REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat);
|
||||
REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
|
||||
REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
// Complex binary functions
|
||||
|
|
|
|||
|
|
@ -116,6 +116,12 @@ static void hermite_polynomial_h_mps_kernel(TensorIteratorBase& iter) {
|
|||
lib.exec_binary_kernel(iter, "hermite_polynomial_h");
|
||||
}
|
||||
|
||||
static void hermite_polynomial_he_mps_kernel(TensorIteratorBase& iter) {
|
||||
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()),
|
||||
"hermite_polynomial_he_mps not implemented for non-floating types");
|
||||
lib.exec_binary_kernel(iter, "hermite_polynomial_he");
|
||||
}
|
||||
|
||||
static void polar_mps_kernel(TensorIterator& iter) {
|
||||
lib.exec_binary_kernel(iter, "polar");
|
||||
}
|
||||
|
|
@ -135,6 +141,7 @@ REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kerne
|
|||
REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_mps_kernel)
|
||||
REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_mps_kernel)
|
||||
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_mps_kernel)
|
||||
REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_mps_kernel)
|
||||
REGISTER_DISPATCH(polar_stub, &polar_mps_kernel);
|
||||
REGISTER_DISPATCH(complex_stub, &complex_mps_kernel);
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -15363,7 +15363,7 @@
|
|||
- func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck
|
||||
dispatch:
|
||||
CPU, CUDA: special_hermite_polynomial_he_out
|
||||
CPU, CUDA, MPS: special_hermite_polynomial_he_out
|
||||
python_module: special
|
||||
structured_inherits: TensorIteratorBase
|
||||
structured: True
|
||||
|
|
|
|||
|
|
@ -1753,5 +1753,36 @@ inline float hermite_polynomial_h_forward(T x, int64_t n) {
|
|||
return r;
|
||||
} // hermite_polynomial_h_forward(T x, int64_t n)
|
||||
|
||||
template <typename T>
|
||||
inline float hermite_polynomial_he_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
if (n == 0) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
return x;
|
||||
}
|
||||
|
||||
if (n > getHermitianLimit<T>()) {
|
||||
return NAN;
|
||||
}
|
||||
|
||||
float p = 1.0;
|
||||
float q = x;
|
||||
float r;
|
||||
|
||||
for (int64_t k = 1; k < n; k++) {
|
||||
r = x * q - k * p;
|
||||
p = q;
|
||||
q = r;
|
||||
}
|
||||
|
||||
return r;
|
||||
} // hermite_polynomial_he_forward(T x, int64_t n)
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ class MPSBasicTests(TestCase):
|
|||
"chebyshev_polynomial_u",
|
||||
"chebyshev_polynomial_v",
|
||||
"chebyshev_polynomial_w",
|
||||
"hermite_polynomial_h",
|
||||
"hermite_polynomial_he",
|
||||
],
|
||||
)
|
||||
def test_pointwise_binary_op(self, op_name):
|
||||
|
|
|
|||
|
|
@ -12525,7 +12525,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
"erfcx",
|
||||
"gammainc",
|
||||
"gammaincc",
|
||||
"hermite_polynomial_he",
|
||||
"laguerre_polynomial_l",
|
||||
"legendre_polynomial_p",
|
||||
"log_ndtr",
|
||||
|
|
|
|||
|
|
@ -656,7 +656,6 @@ def mps_ops_modifier(ops):
|
|||
'sparse.mmreduce': None,
|
||||
'special.airy_ai': None,
|
||||
'special.erfcx': None,
|
||||
'special.hermite_polynomial_he': None,
|
||||
'special.laguerre_polynomial_l': None,
|
||||
'special.log_ndtr': None,
|
||||
'special.ndtri': None,
|
||||
|
|
@ -713,6 +712,7 @@ def mps_ops_modifier(ops):
|
|||
'special.chebyshev_polynomial_t': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
'special.chebyshev_polynomial_u': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
'special.hermite_polynomial_h': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
'special.hermite_polynomial_he': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
|
||||
# entr does not support boolean types
|
||||
'special.entr': [torch.bool],
|
||||
|
|
|
|||
|
|
@ -477,6 +477,10 @@ class MetalOverrides(OpOverrides):
|
|||
def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::hermite_polynomial_h_forward({x}, {n})"
|
||||
|
||||
@staticmethod
|
||||
def hermite_polynomial_he(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::hermite_polynomial_he_forward({x}, {n})"
|
||||
|
||||
|
||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user