diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 04a09fb77c4..4318077a7de 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -128,5 +128,22 @@ using vec4type_t = typename detail::vectypes::type4; template using opmath_t = typename detail::OpMathType::type; + +// TODO: Move it to type_traits header may be +template +using result_of = decltype(::metal::declval()(::metal::declval()...)); + +template +constexpr constant bool is_complex_v = + ::metal::is_same_v || ::metal::is_same_v; + +template +constexpr constant bool is_scalar_floating_point_v = + ::metal::is_floating_point_v && ::metal::is_scalar_v; + +template +constexpr constant bool is_scalar_integral_v = + ::metal::is_integral_v && ::metal::is_scalar_v; + } // namespace metal } // namespace c10