#include "pybind_state_dlpack.h" namespace caffe2 { namespace python { namespace py = pybind11; const DLDeviceType* CaffeToDLDeviceType(int device_type) { static std::map dl_device_type_map{ {PROTO_CPU, kDLCPU}, {PROTO_CUDA, kDLGPU}, }; const auto it = dl_device_type_map.find(device_type); return it == dl_device_type_map.end() ? nullptr : &it->second; } const DLDataType* CaffeToDLType(const TypeMeta& meta) { static std::map dl_type_map{ {TypeMeta::Id(), DLDataType{0, 8, 1}}, {TypeMeta::Id(), DLDataType{0, 16, 1}}, {TypeMeta::Id(), DLDataType{0, 32, 1}}, {TypeMeta::Id(), DLDataType{0, 64, 1}}, {TypeMeta::Id(), DLDataType{1, 8, 1}}, {TypeMeta::Id(), DLDataType{1, 16, 1}}, {TypeMeta::Id(), DLDataType{2, 16, 1}}, {TypeMeta::Id(), DLDataType{2, 32, 1}}, {TypeMeta::Id(), DLDataType{2, 64, 1}}, }; const auto it = dl_type_map.find(meta.id()); return it == dl_type_map.end() ? nullptr : &it->second; } const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type) { try { if (dl_type.lanes != 1) { throw std::invalid_argument("invalid type"); } static std::map> dl_caffe_type_map{ {0, std::map{ {8, TypeMeta::Make()}, {16, TypeMeta::Make()}, {32, TypeMeta::Make()}, {64, TypeMeta::Make()}, }}, {1, std::map{ {8, TypeMeta::Make()}, {16, TypeMeta::Make()}, }}, {2, std::map{ {16, TypeMeta::Make()}, {32, TypeMeta::Make()}, {64, TypeMeta::Make()}, }}, }; if (!dl_caffe_type_map.count(dl_type.code)) { throw std::invalid_argument("invalid type"); } const auto& bits_map = dl_caffe_type_map.at(dl_type.code); if (!bits_map.count(dl_type.bits)) { throw std::invalid_argument("invalid type"); } return bits_map.at(dl_type.bits); } catch (const std::invalid_argument& e) { CAFFE_THROW( "Unsupported DLDataType: ", dl_type.code, dl_type.bits, dl_type.lanes); } } } // namespace python } // namespace caffe2