From e8900fbe4f324173f602de816f792a6e5bf2f729 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 4 Mar 2025 07:46:39 -0800 Subject: [PATCH] [MPS] Add some useful utils (#148448) Like `is_compex_v`, `is_scalar_intergral_v`, `result_of` etc Pull Request resolved: https://github.com/pytorch/pytorch/pull/148448 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #148398, #148399 --- c10/metal/utils.h | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 04a09fb77c4..4318077a7de 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -128,5 +128,22 @@ using vec4type_t = typename detail::vectypes::type4; template using opmath_t = typename detail::OpMathType::type; + +// TODO: Move it to type_traits header may be +template +using result_of = decltype(::metal::declval()(::metal::declval()...)); + +template +constexpr constant bool is_complex_v = + ::metal::is_same_v || ::metal::is_same_v; + +template +constexpr constant bool is_scalar_floating_point_v = + ::metal::is_floating_point_v && ::metal::is_scalar_v; + +template +constexpr constant bool is_scalar_integral_v = + ::metal::is_integral_v && ::metal::is_scalar_v; + } // namespace metal } // namespace c10