mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20511 Removed cyclic dependency of caffe2/core/net.h and workspace.h Differential Revision: D15303412 fbshipit-source-id: 6e772e372cd0cf2af05d7815f1df8ae20bc2a65e
127 lines
3.7 KiB
C++
127 lines
3.7 KiB
C++
#include "caffe2/predictor/predictor.h"

#include <algorithm>
#include <unordered_set>

#include "caffe2/core/init.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
class Workspace;
|
|
namespace {
|
|
|
|
// Validates that `name` refers to an existing blob in `ws` that holds a CPU
// tensor. Fails via CAFFE_ENFORCE (throwing) when either condition is false.
void enforceIsTensor(Workspace* ws, const std::string& name) {
  const Blob* const found = ws->GetBlob(name);
  CAFFE_ENFORCE(found, "Blob does not exist: ", name);

  const bool holds_cpu_tensor = BlobIsTensorType(*found, CPU);
  CAFFE_ENFORCE(holds_cpu_tensor, "Blob is not a CPU Tensor: ", name);
}
|
|
|
|
// Returns the blob called `name` from `ws` after checking (via
// enforceIsTensor) that it exists and holds a CPU tensor.
Blob* getBlob(Workspace* ws, const std::string& name) {
  enforceIsTensor(ws, name);

  Blob* const found = ws->GetBlob(name);
  // Defensive re-check of the pointer actually being returned.
  CAFFE_ENFORCE(found, "Blob: ", name, " does not exist");
  return found;
}
|
|
|
|
// Looks up `name` in `ws` (validated by getBlob) and returns the CPU tensor
// stored in that blob as a const reference.
const Tensor& getTensor(Workspace* ws, const std::string& name) {
  Blob* const blob = getBlob(ws, name);
  Tensor* const tensor = BlobGetMutableTensor(blob, CPU);
  return *tensor;
}
|
|
|
|
} // namespace
|
|
|
|
// Convenience constructor: bundles the raw init/run nets and options into a
// PredictorConfig through makePredictorConfig and delegates to the
// PredictorConfig-taking constructor.
Predictor::Predictor(
    const NetDef& init_net,
    const NetDef& run_net,
    Workspace* parent,
    bool run_init,
    int optimization)
    : Predictor(makePredictorConfig(
          init_net, run_net, parent, run_init, optimization)) {}
|
|
|
|
// Takes ownership of `config`, creates a CPU tensor blob for every external
// input of the predict net that is not already present in the workspace, and
// instantiates the predict net in that workspace (enforcing success).
Predictor::Predictor(PredictorConfig config) : config_(std::move(config)) {
  const auto& existing_blobs = config_.ws->Blobs();
  const std::unordered_set<std::string> already_present(
      existing_blobs.begin(), existing_blobs.end());

  for (const std::string& input_name : config_.predict_net->external_input()) {
    if (already_present.find(input_name) != already_present.end()) {
      continue;
    }
    Blob* const fresh = config_.ws->CreateBlob(input_name);
    BlobGetMutableTensor(fresh, CPU);
  }

  auto* const created_net = config_.ws->CreateNet(config_.predict_net);
  CAFFE_ENFORCE(created_net);
}
|
|
|
|
// Positional run: binds inputs[i] to the i-th external input of the predict
// net (sharing storage via UnsafeSharedInstance), runs the net, and replaces
// the contents of `outputs` with the net's external outputs in declared order.
// Returns false (leaving `outputs` untouched) if RunNet fails.
bool Predictor::operator()(const TensorList& inputs, TensorList* outputs) {
  Workspace* const ws = config_.ws.get();
  const auto& net = *config_.predict_net;

  const int num_inputs = net.external_input_size();
  CAFFE_ENFORCE(
      inputs.size() <= static_cast<size_t>(num_inputs),
      "Too many inputs: got ",
      inputs.size(),
      " but the net declares ",
      num_inputs);

  for (size_t i = 0; i < inputs.size(); ++i) {
    // This is evil and shares the same underlying tensor
    BlobSetTensor(
        getBlob(ws, net.external_input(static_cast<int>(i))),
        inputs[i].UnsafeSharedInstance());
  }

  if (!ws->RunNet(net.name())) {
    return false;
  }

  // Use the protobuf's own (int) index type to avoid a signed/unsigned mix,
  // and size the vector once up front.
  const int num_outputs = net.external_output_size();
  outputs->clear();
  outputs->reserve(static_cast<size_t>(num_outputs));
  for (int i = 0; i < num_outputs; ++i) {
    outputs->emplace_back(
        getTensor(ws, net.external_output(i)).UnsafeSharedInstance());
  }
  return true;
}
|
|
|
|
bool Predictor::run_map_workspace(const TensorMap& inputs) {
|
|
if (!config_.input_names.empty()) {
|
|
CAFFE_ENFORCE_EQ(inputs.size(), input_names().size());
|
|
}
|
|
for (auto& input : inputs) {
|
|
if (!input_names().empty()) {
|
|
CAFFE_ENFORCE(
|
|
std::find(input_names().begin(), input_names().end(), input.first) !=
|
|
input_names().end(),
|
|
"Input can't be found: ",
|
|
input.first);
|
|
}
|
|
// This is evil and shares the same underlying tensor
|
|
BlobSetTensor(
|
|
getBlob(config_.ws.get(), input.first),
|
|
input.second.UnsafeSharedInstance());
|
|
}
|
|
|
|
return config_.ws->RunNet(config_.predict_net->name());
|
|
}
|
|
|
|
// Named-input run: feeds `inputs` through run_map_workspace, then replaces the
// contents of `outputs` with the net's external outputs in declared order
// (sharing storage via UnsafeSharedInstance). Returns false, leaving `outputs`
// untouched, if the run fails.
bool Predictor::operator()(const TensorMap& inputs, TensorList* outputs) {
  if (!run_map_workspace(inputs)) {
    return false;
  }

  Workspace* const ws = config_.ws.get();
  const auto& net = *config_.predict_net;
  // Use the protobuf's own (int) index type to avoid a signed/unsigned mix,
  // and size the vector once up front.
  const int num_outputs = net.external_output_size();
  outputs->clear();
  outputs->reserve(static_cast<size_t>(num_outputs));
  for (int i = 0; i < num_outputs; ++i) {
    outputs->push_back(
        getTensor(ws, net.external_output(i)).UnsafeSharedInstance());
  }
  return true;
}
|
|
|
|
// Named-input, named-output run: feeds `inputs` through run_map_workspace and
// then stores, for every name in output_names(), the workspace tensor of that
// name into `outputs` (sharing storage via UnsafeSharedInstance). Entries for
// those names are always refreshed; other entries in `outputs` are left alone.
// Returns false if the run fails.
bool Predictor::operator()(const TensorMap& inputs, TensorMap* outputs) {
  if (!run_map_workspace(inputs)) {
    return false;
  }

  Workspace* const ws = config_.ws.get();
  for (const std::string& outputName : output_names()) {
    // emplace() does nothing when the key already exists, so a map reused
    // across calls would keep the previous run's tensor. Drop it first.
    outputs->erase(outputName);
    outputs->emplace(
        outputName, getTensor(ws, outputName).UnsafeSharedInstance());
  }
  return true;
}
|
|
|
|
} // namespace caffe2
|