mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move AT_FORALL_... macros and ScalarTypeToCPPTypeT to headeronly (#164350)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164350 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
e5a9c247bc
commit
48064acf37
|
|
@ -28,101 +28,8 @@
|
|||
|
||||
namespace c10 {
|
||||
|
||||
// [dtype Macros note] For the macros below:
|
||||
//
|
||||
// For users: If you want to macro some code for all non-QInt scalar types
|
||||
// (i.e. types with complete information), you probably want one of the
|
||||
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
|
||||
// designed to behave similarly to the Dispatch macros with the same name.
|
||||
//
|
||||
// For adding a new dtype: In the beginning, we had an idea that there was a
|
||||
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
|
||||
// iterate over them. But over the years we added weird types which couldn't
|
||||
// be handled uniformly everywhere and so in the end we ended up with some
|
||||
// mish-mosh of some helper macros, but mostly use sites making a call about
|
||||
// what dtypes they can or can't support. So if you want to add a new dtype,
|
||||
// the preferred resolution is to find a dtype similar to what you want,
|
||||
// grep for it and edit all the sites you find this way. If you need to add
|
||||
// a completely new kind of dtype, you're going to have to laboriously audit
|
||||
// all of the sites everywhere to figure out how it should work. Consulting
|
||||
// some old PRs where we added new dtypes (check history of this file) can
|
||||
// help give you an idea where to start.
|
||||
|
||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||
// into this macro (and change the name). But beware: convert()
|
||||
// doesn't work for all the conversions you need...
|
||||
//
|
||||
// TODO: To add unsigned int types here, we must define accumulate type.
|
||||
// But uint8 currently accumulates into int64, so we would have to make
|
||||
// an inconsistent choice for the larger types. Difficult.
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
// This macro controls many of our C++ APIs, including constructors
|
||||
// for Scalar as well as the data() and item() accessors on Tensor
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<c10::Half>, ComplexHalf) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
namespace impl {
|
||||
|
||||
// These are used to map ScalarTypes to C++ types.
|
||||
|
||||
template <c10::ScalarType N>
|
||||
struct ScalarTypeToCPPType;
|
||||
|
||||
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
|
||||
using type = cpp_type; \
|
||||
\
|
||||
/* This is a workaround for the CUDA bug which prevents */ \
|
||||
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
|
||||
/* ambiguous reference which can't be resolved. For some reason it */ \
|
||||
/* can't pick between at::detail and at::cuda::detail. */ \
|
||||
/* For repro example, please see: */ \
|
||||
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
|
||||
/* TODO: remove once the bug is fixed. */ \
|
||||
static type t; \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
|
||||
|
||||
#undef SPECIALIZE_ScalarTypeToCPPType
|
||||
|
||||
template <c10::ScalarType N>
|
||||
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
||||
|
||||
} // namespace impl
|
||||
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
|
||||
// regarding macros.
|
||||
|
||||
template <typename T>
|
||||
struct CppTypeToScalarType;
|
||||
|
|
@ -138,130 +45,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
|||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
// NB: despite their generic-sounding names, the macros that don't take _AND
|
||||
// are mostly only used by tensorexpr
|
||||
#define AT_FORALL_INT_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double)
|
||||
|
||||
// These macros are often controlling how many template instantiations we
|
||||
// create for kernels. It is typically inappropriate to add new dtypes here,
|
||||
// instead, new types should be added to use sites on a case-by-case basis.
|
||||
// We generally are not accepting new dtypes due to binary size concerns.
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE>::t), \
|
||||
SCALARTYPE)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND7( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
SCALARTYPE7, \
|
||||
_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE5>::t), \
|
||||
SCALARTYPE5) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE6>::t), \
|
||||
SCALARTYPE6) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE7>::t), \
|
||||
SCALARTYPE7)
|
||||
|
||||
#define AT_FORALL_QINT_TYPES(_) \
|
||||
_(c10::qint8, QInt8) \
|
||||
_(c10::quint8, QUInt8) \
|
||||
_(c10::qint32, QInt32) \
|
||||
_(c10::quint4x2, QUInt4x2) \
|
||||
_(c10::quint2x4, QUInt2x4)
|
||||
|
||||
#define AT_FORALL_FLOAT8_TYPES(_) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
#define AT_FORALL_COMPLEX_TYPES(_) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble)
|
||||
|
||||
#define DEFINE_CONSTANT(_, name) \
|
||||
constexpr ScalarType k##name = ScalarType::name;
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
|
|||
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
|
||||
)
|
||||
|
|
|
|||
55
test/cpp/aoti_abi_check/test_scalartype.cpp
Normal file
55
test/cpp/aoti_abi_check/test_scalartype.cpp
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
TEST(TestScalarType, ScalarTypeToCPPTypeT) {
|
||||
using torch::headeronly::ScalarType;
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT;
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
EXPECT_EQ(typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE));
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
{ \
|
||||
EXPECT_EQ( \
|
||||
typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); \
|
||||
count++; \
|
||||
}
|
||||
|
||||
#define TEST_FORALL(M, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestScalarType, M) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t count = 0; \
|
||||
M(__VA_ARGS__ DEFINE_CHECK); \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
}
|
||||
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46)
|
||||
TEST_FORALL(AT_FORALL_INT_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, )
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND2, 9, Bool, Half, )
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND3, 10, Bool, Half, ComplexFloat, )
|
||||
TEST_FORALL(
|
||||
AT_FORALL_SCALAR_TYPES_AND7,
|
||||
14,
|
||||
Bool,
|
||||
Half,
|
||||
ComplexHalf,
|
||||
ComplexFloat,
|
||||
ComplexDouble,
|
||||
UInt16,
|
||||
UInt32, )
|
||||
TEST_FORALL(AT_FORALL_QINT_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_FLOAT8_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
|
||||
|
||||
#undef DEFINE_CHECK
|
||||
#undef TEST_FORALL
|
||||
|
|
@ -120,3 +120,16 @@ COMPILE_TIME_MAX_DEVICE_TYPES
|
|||
NumScalarTypes
|
||||
ScalarType
|
||||
# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType
|
||||
ScalarTypeToCPPTypeT
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
|
||||
AT_FORALL_INT_TYPES
|
||||
AT_FORALL_SCALAR_TYPES
|
||||
AT_FORALL_SCALAR_TYPES_AND
|
||||
AT_FORALL_SCALAR_TYPES_AND2
|
||||
AT_FORALL_SCALAR_TYPES_AND3
|
||||
AT_FORALL_SCALAR_TYPES_AND7
|
||||
AT_FORALL_QINT_TYPES
|
||||
AT_FORALL_FLOAT8_TYPES
|
||||
AT_FORALL_COMPLEX_TYPES
|
||||
|
|
|
|||
|
|
@ -30,7 +30,70 @@ struct dummy_uint1_7_t {};
|
|||
template <unsigned int N>
|
||||
struct dummy_int1_7_t {};
|
||||
|
||||
// See [dtype Macros note] in c10/core/ScalarType.h regarding macros
|
||||
// [dtype Macros note] For the macros below:
|
||||
//
|
||||
// For users: If you want to macro some code for all non-QInt scalar types
|
||||
// (i.e. types with complete information), you probably want one of the
|
||||
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
|
||||
// designed to behave similarly to the Dispatch macros with the same name.
|
||||
//
|
||||
// For adding a new dtype: In the beginning, we had an idea that there was a
|
||||
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
|
||||
// iterate over them. But over the years we added weird types which couldn't
|
||||
// be handled uniformly everywhere and so in the end we ended up with some
|
||||
// mish-mosh of some helper macros, but mostly use sites making a call about
|
||||
// what dtypes they can or can't support. So if you want to add a new dtype,
|
||||
// the preferred resolution is to find a dtype similar to what you want,
|
||||
// grep for it and edit all the sites you find this way. If you need to add
|
||||
// a completely new kind of dtype, you're going to have to laboriously audit
|
||||
// all of the sites everywhere to figure out how it should work. Consulting
|
||||
// some old PRs where we added new dtypes (check history of this file) can
|
||||
// help give you an idea where to start.
|
||||
|
||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||
// into this macro (and change the name). But beware: convert()
|
||||
// doesn't work for all the conversions you need...
|
||||
//
|
||||
// TODO: To add unsigned int types here, we must define accumulate type.
|
||||
// But uint8 currently accumulates into int64, so we would have to make
|
||||
// an inconsistent choice for the larger types. Difficult.
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
// This macro controls many of our C++ APIs, including constructors
|
||||
// for Scalar as well as the data() and item() accessors on Tensor
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<c10::Half>, ComplexHalf) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
// NB: Order matters for this macro; it is relied upon in
|
||||
// _promoteTypesLookup and the serialization format.
|
||||
|
|
@ -82,6 +145,130 @@ struct dummy_int1_7_t {};
|
|||
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
|
||||
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
|
||||
|
||||
// NB: despite their generic-sounding names, the macros that don't take _AND
|
||||
// are mostly only used by tensorexpr
|
||||
#define AT_FORALL_INT_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double)
|
||||
|
||||
// These macros are often controlling how many template instantiations we
|
||||
// create for kernels. It is typically inappropriate to add new dtypes here,
|
||||
// instead, new types should be added to use sites on a case-by-case basis.
|
||||
// We generally are not accepting new dtypes due to binary size concerns.
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE>::t), \
|
||||
SCALARTYPE)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND7( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
SCALARTYPE7, \
|
||||
_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE5>::t), \
|
||||
SCALARTYPE5) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE6>::t), \
|
||||
SCALARTYPE6) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE7>::t), \
|
||||
SCALARTYPE7)
|
||||
|
||||
#define AT_FORALL_QINT_TYPES(_) \
|
||||
_(c10::qint8, QInt8) \
|
||||
_(c10::quint8, QUInt8) \
|
||||
_(c10::qint32, QInt32) \
|
||||
_(c10::quint4x2, QUInt4x2) \
|
||||
_(c10::quint2x4, QUInt2x4)
|
||||
|
||||
#define AT_FORALL_FLOAT8_TYPES(_) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
#define AT_FORALL_COMPLEX_TYPES(_) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble)
|
||||
|
||||
enum class ScalarType : int8_t {
|
||||
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
|
||||
|
|
@ -93,6 +280,37 @@ enum class ScalarType : int8_t {
|
|||
constexpr uint16_t NumScalarTypes =
|
||||
static_cast<uint16_t>(ScalarType::NumOptions);
|
||||
|
||||
namespace impl {
|
||||
|
||||
// These are used to map ScalarTypes to C++ types.
|
||||
|
||||
template <c10::ScalarType N>
|
||||
struct ScalarTypeToCPPType;
|
||||
|
||||
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
|
||||
using type = cpp_type; \
|
||||
\
|
||||
/* This is a workaround for the CUDA bug which prevents */ \
|
||||
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
|
||||
/* ambiguous reference which can't be resolved. For some reason it */ \
|
||||
/* can't pick between at::detail and at::cuda::detail. */ \
|
||||
/* For repro example, please see: */ \
|
||||
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
|
||||
/* TODO: remove once the bug is fixed. */ \
|
||||
static type t; \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
|
||||
|
||||
#undef SPECIALIZE_ScalarTypeToCPPType
|
||||
|
||||
template <c10::ScalarType N>
|
||||
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
||||
|
||||
} // namespace impl
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::headeronly {
|
||||
|
|
@ -100,4 +318,7 @@ using c10::dummy_int1_7_t;
|
|||
using c10::dummy_uint1_7_t;
|
||||
using c10::NumScalarTypes;
|
||||
using c10::ScalarType;
|
||||
namespace impl {
|
||||
using c10::impl::ScalarTypeToCPPTypeT;
|
||||
} // namespace impl
|
||||
} // namespace torch::headeronly
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user