mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move toUnderlying to headeronly (#165694)
As in the title. Required in upper PRs of this ghstack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165694 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
4fae6968b1
commit
d01f15152c
|
|
@ -137,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
|
|||
}
|
||||
}
|
||||
|
||||
inline ScalarType toUnderlying(ScalarType t) {
|
||||
switch (t) {
|
||||
case ScalarType::QUInt8:
|
||||
case ScalarType::QUInt4x2:
|
||||
[[fallthrough]];
|
||||
case ScalarType::QUInt2x4:
|
||||
return ScalarType::Byte;
|
||||
case ScalarType::QInt8:
|
||||
return ScalarType::Char;
|
||||
case ScalarType::QInt32:
|
||||
return ScalarType::Int;
|
||||
default:
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isSignedType(ScalarType t) {
|
||||
#define CASE_ISSIGNED(name) \
|
||||
case ScalarType::name: \
|
||||
|
|
|
|||
|
|
@ -74,3 +74,19 @@ TEST(TestScalarType, operator_left_shift) {
|
|||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, toUnderlying) {
|
||||
using torch::headeronly::ScalarType;
|
||||
using torch::headeronly::toUnderlying;
|
||||
|
||||
EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte);
|
||||
EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte);
|
||||
EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte);
|
||||
EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char);
|
||||
EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int);
|
||||
#define DEFINE_CHECK(_, name) \
|
||||
EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name);
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
|
||||
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,3 +135,4 @@ AT_FORALL_FLOAT8_TYPES
|
|||
AT_FORALL_COMPLEX_TYPES
|
||||
toString
|
||||
<<
|
||||
toUnderlying
|
||||
|
|
|
|||
|
|
@ -318,6 +318,22 @@ inline std::ostream& operator<<(
|
|||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
inline ScalarType toUnderlying(ScalarType t) {
|
||||
switch (t) {
|
||||
case ScalarType::QUInt8:
|
||||
case ScalarType::QUInt4x2:
|
||||
[[fallthrough]];
|
||||
case ScalarType::QUInt2x4:
|
||||
return ScalarType::Byte;
|
||||
case ScalarType::QInt8:
|
||||
return ScalarType::Char;
|
||||
case ScalarType::QInt32:
|
||||
return ScalarType::Int;
|
||||
default:
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::headeronly {
|
||||
|
|
@ -330,4 +346,5 @@ using c10::impl::ScalarTypeToCPPTypeT;
|
|||
} // namespace impl
|
||||
using c10::toString;
|
||||
using c10::operator<<;
|
||||
using c10::toUnderlying;
|
||||
} // namespace torch::headeronly
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user