mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
This PR is created to replace the reverted PR https://github.com/pytorch/pytorch/pull/164405 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166018 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
f9953e0f61
commit
4fae6968b1
|
|
@ -52,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
|||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
||||
#undef DEFINE_CONSTANT
|
||||
|
||||
inline const char* toString(ScalarType t) {
|
||||
#define DEFINE_CASE(_, name) \
|
||||
case ScalarType::name: \
|
||||
return #name;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||
default:
|
||||
return "UNKNOWN_SCALAR";
|
||||
}
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
|
||||
inline size_t elementSize(ScalarType t) {
|
||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
||||
case ScalarType::name: \
|
||||
|
|
@ -308,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
|
|||
|
||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
at::ScalarType scalar_type) {
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
// Returns a pair of strings representing the names for each dtype.
|
||||
// The returned pair is (name, legacy_name_if_applicable)
|
||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||
|
|
|
|||
|
|
@ -53,3 +53,24 @@ TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
|
|||
|
||||
#undef DEFINE_CHECK
|
||||
#undef TEST_FORALL
|
||||
|
||||
TEST(TestScalarType, toString) {
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, operator_left_shift) {
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(_, name) \
|
||||
{ \
|
||||
std::stringstream ss; \
|
||||
ss << ScalarType::name; \
|
||||
EXPECT_EQ(ss.str(), #name); \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,3 +133,5 @@ AT_FORALL_SCALAR_TYPES_AND7
|
|||
AT_FORALL_QINT_TYPES
|
||||
AT_FORALL_FLOAT8_TYPES
|
||||
AT_FORALL_COMPLEX_TYPES
|
||||
toString
|
||||
<<
|
||||
|
|
|
|||
|
|
@ -299,6 +299,25 @@ using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
|||
|
||||
} // namespace impl
|
||||
|
||||
inline const char* toString(ScalarType t) {
|
||||
#define DEFINE_CASE(_, name) \
|
||||
case ScalarType::name: \
|
||||
return #name;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||
default:
|
||||
return "UNKNOWN_SCALAR";
|
||||
}
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
c10::ScalarType scalar_type) {
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::headeronly {
|
||||
|
|
@ -309,4 +328,6 @@ using c10::ScalarType;
|
|||
namespace impl {
|
||||
using c10::impl::ScalarTypeToCPPTypeT;
|
||||
} // namespace impl
|
||||
using c10::toString;
|
||||
using c10::operator<<;
|
||||
} // namespace torch::headeronly
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user