mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Print the available modules when throwing a module-not-found exception; I believe that improves UX. Differential Revision: D36580924. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78101. Approved by: https://github.com/mikeiovine
92 lines
2.6 KiB
C++
92 lines
2.6 KiB
C++
#include "caffe2/predictor/predictor_utils.h"
|
|
|
|
#include "caffe2/core/blob.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include "caffe2/proto/predictor_consts.pb.h"
|
|
#include "caffe2/utils/proto_utils.h"
|
|
|
|
namespace caffe2 {
|
|
namespace predictor_utils {
|
|
|
|
/**
 * Looks up the net stored under `name` in `def`.
 *
 * Walks the entries of `def.nets()` in order and returns the value of the
 * first entry whose key equals `name`. The result is handed back by
 * reference (no copy is made here), so callers should keep `def` alive for
 * as long as they use it.
 *
 * If no entry matches, an error is raised through CAFFE_THROW. Its message
 * contains the requested name followed by the keys of every entry in
 * `def.nets()`, separated by ", ".
 */
TORCH_API const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
  for (const auto& entry : def.nets()) {
    if (entry.key() == name) {
      return entry.value();
    }
  }

  // Lookup failed. The list of known keys is only needed for the error
  // message, so it is assembled here rather than during the search above;
  // a successful lookup therefore does no string concatenation at all.
  std::string net_names;
  bool need_separator = false;
  for (const auto& entry : def.nets()) {
    if (need_separator) {
      net_names += ", ";
    }
    need_separator = true;
    net_names += entry.key();
  }
  CAFFE_THROW("Net not found: ", name, "; available nets: ", net_names);
}
|
|
|
|
/**
 * Reads the MetaNetDef stored under `key` from the database behind `cursor`.
 *
 * `cursor` must be non-null (checked with CAFFE_ENFORCE). When the cursor
 * reports that it supports seeking, it is first asked to seek to `key`;
 * after that the cursor is advanced entry by entry until one whose key
 * equals `key` is found.
 *
 * The matching entry's value is parsed as a BlobProto and deserialized into
 * a Blob, which must hold a string. That string is then parsed as a
 * MetaNetDef, which is returned. Each of these steps is guarded by
 * CAFFE_ENFORCE.
 *
 * If the cursor becomes invalid before a matching key is seen, an error is
 * raised through CAFFE_THROW.
 */
std::unique_ptr<MetaNetDef> extractMetaNetDef(
    db::Cursor* cursor,
    const std::string& key) {
  CAFFE_ENFORCE(cursor);
  if (cursor->SupportsSeek()) {
    cursor->Seek(key);
  }

  while (cursor->Valid()) {
    if (cursor->key() == key) {
      // This is the entry we are after: unwrap BlobProto -> Blob -> string
      // -> MetaNetDef.
      BlobProto serialized;
      CAFFE_ENFORCE(serialized.ParseFromString(cursor->value()));

      Blob holder;
      DeserializeBlob(serialized, &holder);
      CAFFE_ENFORCE(holder.template IsType<string>());

      auto result = std::make_unique<MetaNetDef>();
      CAFFE_ENFORCE(result->ParseFromString(holder.template Get<string>()));
      return result;
    }
    cursor->Next();
  }

  CAFFE_THROW("Failed to find in db the key: ", key);
}
|
|
|
|
std::unique_ptr<MetaNetDef> runGlobalInitialization(
|
|
std::unique_ptr<db::DBReader> db,
|
|
Workspace* master) {
|
|
CAFFE_ENFORCE(db.get());
|
|
auto* cursor = db->cursor();
|
|
|
|
auto metaNetDef = extractMetaNetDef(
|
|
cursor, PredictorConsts::default_instance().meta_net_def());
|
|
if (metaNetDef->has_modelinfo()) {
|
|
CAFFE_ENFORCE(
|
|
metaNetDef->modelinfo().predictortype() ==
|
|
PredictorConsts::default_instance().single_predictor(),
|
|
"Can only load single predictor");
|
|
}
|
|
VLOG(1) << "Extracted meta net def";
|
|
|
|
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
|
const auto globalInitNet = getNet(
|
|
*metaNetDef, PredictorConsts::default_instance().global_init_net_type());
|
|
VLOG(1) << "Global init net: " << ProtoDebugString(globalInitNet);
|
|
|
|
// Now, pass away ownership of the DB into the master workspace for
|
|
// use by the globalInitNet.
|
|
master->CreateBlob(PredictorConsts::default_instance().predictor_dbreader())
|
|
->Reset(db.release());
|
|
|
|
// Now, with the DBReader set, we can run globalInitNet.
|
|
CAFFE_ENFORCE(
|
|
master->RunNetOnce(globalInitNet),
|
|
"Failed running the globalInitNet: ",
|
|
ProtoDebugString(globalInitNet));
|
|
|
|
return metaNetDef;
|
|
}
|
|
|
|
} // namespace predictor_utils
|
|
} // namespace caffe2
|