#include "caffe2/predictor/predictor_utils.h"

#include <memory>
#include <string>

#include "caffe2/core/blob.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/predictor_consts.pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {
namespace predictor_utils {

// Returns the NetDef stored in `def.nets()` under the key `name`.
//
// The returned reference points into `def`, so it is only valid while `def`
// is alive and unmodified.
//
// Throws (via CAFFE_THROW) if no entry has key `name`; the message lists every
// available key, comma-separated.
TORCH_API const NetDef& getNet(
    const MetaNetDef& def,
    const std::string& name) {
  for (const auto& entry : def.nets()) {
    if (entry.key() == name) {
      return entry.value();
    }
  }

  // Only build the list of available names on the failure path; the lookup
  // itself does not need it.
  std::string netNames;
  bool isFirst = true;
  for (const auto& entry : def.nets()) {
    if (!isFirst) {
      netNames += ", ";
    }
    isFirst = false;
    netNames += entry.key();
  }
  CAFFE_THROW("Net not found: ", name, "; available nets: ", netNames);
}

// Reads the record stored under `key` from `cursor` and parses it as a
// MetaNetDef.
//
// If the cursor supports seeking, it is first positioned at `key`. It is then
// advanced until a record whose key equals `key` is found. That record's value
// must be a serialized BlobProto whose deserialized Blob holds a std::string;
// the string is parsed as a MetaNetDef.
//
// @param cursor Non-null DB cursor; it is advanced by this call.
// @param key    Exact record key to look for.
// @return       The parsed MetaNetDef (never null).
// Throws (via CAFFE_ENFORCE/CAFFE_THROW) if `cursor` is null, the key is not
// found, or any parsing/type check fails.
std::unique_ptr<MetaNetDef> extractMetaNetDef(
    db::Cursor* cursor,
    const std::string& key) {
  CAFFE_ENFORCE(cursor);
  if (cursor->SupportsSeek()) {
    cursor->Seek(key);
  }
  for (; cursor->Valid(); cursor->Next()) {
    if (cursor->key() != key) {
      continue;
    }
    // We've found a match. Parse it out.
    BlobProto proto;
    CAFFE_ENFORCE(
        proto.ParseFromString(cursor->value()),
        "Failed to parse BlobProto for key: ",
        key);
    Blob blob;
    DeserializeBlob(proto, &blob);
    CAFFE_ENFORCE(
        blob.IsType<std::string>(),
        "Expected a string blob for key: ",
        key);
    auto def = std::make_unique<MetaNetDef>();
    CAFFE_ENFORCE(
        def->ParseFromString(blob.Get<std::string>()),
        "Failed to parse MetaNetDef for key: ",
        key);
    return def;
  }
  CAFFE_THROW("Failed to find in db the key: ", key);
}

// Loads the MetaNetDef from `db` and runs its global init net in `master`.
//
// Steps:
//   1. Reads the MetaNetDef stored under the predictor-constants
//      `meta_net_def` key.
//   2. If the MetaNetDef has model info, enforces that it describes a single
//      predictor.
//   3. Moves ownership of `db` into a blob in `master`, named by the
//      `predictor_dbreader` constant, so the init net can read from it.
//   4. Runs the `global_init_net_type` net once in `master`.
//
// @param db     Non-null DB reader; ownership is transferred into `master`.
// @param master Non-null workspace that receives the DB reader blob and in
//               which the global init net is run.
// @return       The extracted MetaNetDef.
// Throws (via CAFFE_ENFORCE/CAFFE_THROW) on any missing input, lookup, or
// net-execution failure.
std::unique_ptr<MetaNetDef> runGlobalInitialization(
    std::unique_ptr<db::DBReader> db,
    Workspace* master) {
  CAFFE_ENFORCE(db.get());
  CAFFE_ENFORCE(master);
  auto* cursor = db->cursor();

  auto metaNetDef = extractMetaNetDef(
      cursor, PredictorConsts::default_instance().meta_net_def());
  if (metaNetDef->has_modelinfo()) {
    CAFFE_ENFORCE(
        metaNetDef->modelinfo().predictortype() ==
            PredictorConsts::default_instance().single_predictor(),
        "Can only load single predictor");
  }
  VLOG(1) << "Extracted meta net def";

  // getNet returns a reference into *metaNetDef, which this function owns
  // until it returns, so the NetDef does not need to be copied.
  const NetDef& globalInitNet = getNet(
      *metaNetDef, PredictorConsts::default_instance().global_init_net_type());
  VLOG(1) << "Global init net: " << ProtoDebugString(globalInitNet);

  // Now, pass away ownership of the DB into the master workspace for use by
  // the globalInitNet. Create the blob before releasing the reader so the
  // reader is never orphaned if blob creation throws.
  Blob* dbReaderBlob = master->CreateBlob(
      PredictorConsts::default_instance().predictor_dbreader());
  dbReaderBlob->Reset(db.release());

  // Now, with the DBReader set, we can run globalInitNet.
  CAFFE_ENFORCE(
      master->RunNetOnce(globalInitNet),
      "Failed running the globalInitNet: ",
      ProtoDebugString(globalInitNet));

  return metaNetDef;
}

} // namespace predictor_utils
} // namespace caffe2