Move toUnderlying to headeronly (#165694)

As in the title. Required in upper PRs of this ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165694
Approved by: https://github.com/janeyx99
This commit is contained in:
Pearu Peterson 2025-10-19 12:54:52 +03:00 committed by PyTorch MergeBot
parent 4fae6968b1
commit d01f15152c
4 changed files with 34 additions and 16 deletions

View File

@ -137,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
} }
} }
inline ScalarType toUnderlying(ScalarType t) {
switch (t) {
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
[[fallthrough]];
case ScalarType::QUInt2x4:
return ScalarType::Byte;
case ScalarType::QInt8:
return ScalarType::Char;
case ScalarType::QInt32:
return ScalarType::Int;
default:
return t;
}
}
inline bool isSignedType(ScalarType t) { inline bool isSignedType(ScalarType t) {
#define CASE_ISSIGNED(name) \ #define CASE_ISSIGNED(name) \
case ScalarType::name: \ case ScalarType::name: \

View File

@ -74,3 +74,19 @@ TEST(TestScalarType, operator_left_shift) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK); AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK #undef DEFINE_CHECK
} }
TEST(TestScalarType, toUnderlying) {
using torch::headeronly::ScalarType;
using torch::headeronly::toUnderlying;
EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char);
EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int);
#define DEFINE_CHECK(_, name) \
EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -135,3 +135,4 @@ AT_FORALL_FLOAT8_TYPES
AT_FORALL_COMPLEX_TYPES AT_FORALL_COMPLEX_TYPES
toString toString
<< <<
toUnderlying

View File

@ -318,6 +318,22 @@ inline std::ostream& operator<<(
return stream << toString(scalar_type); return stream << toString(scalar_type);
} }
inline ScalarType toUnderlying(ScalarType t) {
switch (t) {
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
[[fallthrough]];
case ScalarType::QUInt2x4:
return ScalarType::Byte;
case ScalarType::QInt8:
return ScalarType::Char;
case ScalarType::QInt32:
return ScalarType::Int;
default:
return t;
}
}
} // namespace c10 } // namespace c10
namespace torch::headeronly { namespace torch::headeronly {
@ -330,4 +346,5 @@ using c10::impl::ScalarTypeToCPPTypeT;
} // namespace impl } // namespace impl
using c10::toString; using c10::toString;
using c10::operator<<; using c10::operator<<;
using c10::toUnderlying;
} // namespace torch::headeronly } // namespace torch::headeronly