#include #include #include #include "caffe2/utils/cast.h" namespace caffe2 { TEST(CastTest, GetCastDataType) { auto castOp = [](std::string t) { // Ensure lowercase. std::transform(t.begin(), t.end(), t.begin(), ::tolower); auto op = CreateOperatorDef("Cast", "", {}, {}); AddArgument("to", t, &op); return op; }; #define X(t) \ EXPECT_EQ( \ TensorProto_DataType_##t, \ cast::GetCastDataType(ArgumentHelper(castOp(#t)), "to")); X(FLOAT); X(INT32); X(BYTE); X(STRING); X(BOOL); X(UINT8); X(INT8); X(UINT16); X(INT16); X(INT64); X(FLOAT16); X(DOUBLE); #undef X } } // namespace caffe2