From d01f15152cdf9a4b693d5c768cef31a0b2a5b012 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Sun, 19 Oct 2025 12:54:52 +0300 Subject: [PATCH] Move toUnderlying to headeronly (#165694) As in the title. Required in upper PRs of this ghstack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165694 Approved by: https://github.com/janeyx99 --- c10/core/ScalarType.h | 16 ---------------- test/cpp/aoti_abi_check/test_scalartype.cpp | 16 ++++++++++++++++ torch/header_only_apis.txt | 1 + torch/headeronly/core/ScalarType.h | 17 +++++++++++++++++ 4 files changed, 34 insertions(+), 16 deletions(-) diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 24396630417..e0c84370e87 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -137,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) { } } -inline ScalarType toUnderlying(ScalarType t) { - switch (t) { - case ScalarType::QUInt8: - case ScalarType::QUInt4x2: - [[fallthrough]]; - case ScalarType::QUInt2x4: - return ScalarType::Byte; - case ScalarType::QInt8: - return ScalarType::Char; - case ScalarType::QInt32: - return ScalarType::Int; - default: - return t; - } -} - inline bool isSignedType(ScalarType t) { #define CASE_ISSIGNED(name) \ case ScalarType::name: \ diff --git a/test/cpp/aoti_abi_check/test_scalartype.cpp b/test/cpp/aoti_abi_check/test_scalartype.cpp index 13d1b98a770..e0952a48e5a 100644 --- a/test/cpp/aoti_abi_check/test_scalartype.cpp +++ b/test/cpp/aoti_abi_check/test_scalartype.cpp @@ -74,3 +74,19 @@ TEST(TestScalarType, operator_left_shift) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK); #undef DEFINE_CHECK } + +TEST(TestScalarType, toUnderlying) { + using torch::headeronly::ScalarType; + using torch::headeronly::toUnderlying; + + EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte); + EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte); + EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte); + EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char); + EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int); +#define DEFINE_CHECK(_, name) \ + EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name); + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK); + AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK); +#undef DEFINE_CHECK +} diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 8fe36f78063..70165a7493e 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -135,3 +135,4 @@ AT_FORALL_FLOAT8_TYPES AT_FORALL_COMPLEX_TYPES toString << +toUnderlying diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h index e1451e9cbb2..624792c88d4 100644 --- a/torch/headeronly/core/ScalarType.h +++ b/torch/headeronly/core/ScalarType.h @@ -318,6 +318,22 @@ inline std::ostream& operator<<( return stream << toString(scalar_type); } +inline ScalarType toUnderlying(ScalarType t) { + switch (t) { + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + [[fallthrough]]; + case ScalarType::QUInt2x4: + return ScalarType::Byte; + case ScalarType::QInt8: + return ScalarType::Char; + case ScalarType::QInt32: + return ScalarType::Int; + default: + return t; + } +} + } // namespace c10 namespace torch::headeronly { @@ -330,4 +346,5 @@ using c10::impl::ScalarTypeToCPPTypeT; } // namespace impl using c10::toString; using c10::operator<<; +using c10::toUnderlying; } // namespace torch::headeronly