mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move toUnderlying to headeronly (#165694)
As in the title. Required in upper PRs of this ghstack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165694 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
4fae6968b1
commit
d01f15152c
|
|
@ -137,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline ScalarType toUnderlying(ScalarType t) {
|
|
||||||
switch (t) {
|
|
||||||
case ScalarType::QUInt8:
|
|
||||||
case ScalarType::QUInt4x2:
|
|
||||||
[[fallthrough]];
|
|
||||||
case ScalarType::QUInt2x4:
|
|
||||||
return ScalarType::Byte;
|
|
||||||
case ScalarType::QInt8:
|
|
||||||
return ScalarType::Char;
|
|
||||||
case ScalarType::QInt32:
|
|
||||||
return ScalarType::Int;
|
|
||||||
default:
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool isSignedType(ScalarType t) {
|
inline bool isSignedType(ScalarType t) {
|
||||||
#define CASE_ISSIGNED(name) \
|
#define CASE_ISSIGNED(name) \
|
||||||
case ScalarType::name: \
|
case ScalarType::name: \
|
||||||
|
|
|
||||||
|
|
@ -74,3 +74,19 @@ TEST(TestScalarType, operator_left_shift) {
|
||||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||||
#undef DEFINE_CHECK
|
#undef DEFINE_CHECK
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TestScalarType, toUnderlying) {
|
||||||
|
using torch::headeronly::ScalarType;
|
||||||
|
using torch::headeronly::toUnderlying;
|
||||||
|
|
||||||
|
EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte);
|
||||||
|
EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte);
|
||||||
|
EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte);
|
||||||
|
EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char);
|
||||||
|
EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int);
|
||||||
|
#define DEFINE_CHECK(_, name) \
|
||||||
|
EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name);
|
||||||
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
|
||||||
|
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
|
||||||
|
#undef DEFINE_CHECK
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -135,3 +135,4 @@ AT_FORALL_FLOAT8_TYPES
|
||||||
AT_FORALL_COMPLEX_TYPES
|
AT_FORALL_COMPLEX_TYPES
|
||||||
toString
|
toString
|
||||||
<<
|
<<
|
||||||
|
toUnderlying
|
||||||
|
|
|
||||||
|
|
@ -318,6 +318,22 @@ inline std::ostream& operator<<(
|
||||||
return stream << toString(scalar_type);
|
return stream << toString(scalar_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline ScalarType toUnderlying(ScalarType t) {
|
||||||
|
switch (t) {
|
||||||
|
case ScalarType::QUInt8:
|
||||||
|
case ScalarType::QUInt4x2:
|
||||||
|
[[fallthrough]];
|
||||||
|
case ScalarType::QUInt2x4:
|
||||||
|
return ScalarType::Byte;
|
||||||
|
case ScalarType::QInt8:
|
||||||
|
return ScalarType::Char;
|
||||||
|
case ScalarType::QInt32:
|
||||||
|
return ScalarType::Int;
|
||||||
|
default:
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
namespace torch::headeronly {
|
namespace torch::headeronly {
|
||||||
|
|
@ -330,4 +346,5 @@ using c10::impl::ScalarTypeToCPPTypeT;
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
using c10::toString;
|
using c10::toString;
|
||||||
using c10::operator<<;
|
using c10::operator<<;
|
||||||
|
using c10::toUnderlying;
|
||||||
} // namespace torch::headeronly
|
} // namespace torch::headeronly
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user