[caffe2] Avoid some double (and triple) lookups in workspace (#53319)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53319

Noticed these in profiles.

Also switch to `unordered_map`.

Test Plan: Unit tests.

Reviewed By: swolchok

Differential Revision: D26504408

fbshipit-source-id: 9e14d55909a4af019058b8c27c67ee2348cd02a9
This commit is contained in:
Giuseppe Ottaviano 2021-03-04 22:54:26 -08:00 committed by Facebook GitHub Bot
parent 35364c3641
commit 69bb0e0285
2 changed files with 51 additions and 31 deletions

View File

@ -84,7 +84,7 @@ vector<string> Workspace::Blobs() const {
names.push_back(entry.first); names.push_back(entry.first);
} }
for (const auto& forwarded : forwarded_blobs_) { for (const auto& forwarded : forwarded_blobs_) {
const auto parent_ws = forwarded.second.first; const auto* parent_ws = forwarded.second.first;
const auto& parent_name = forwarded.second.second; const auto& parent_name = forwarded.second.second;
if (parent_ws->HasBlob(parent_name)) { if (parent_ws->HasBlob(parent_name)) {
names.push_back(forwarded.first); names.push_back(forwarded.first);
@ -112,13 +112,14 @@ Blob* Workspace::CreateBlob(const string& name) {
} }
Blob* Workspace::CreateLocalBlob(const string& name) { Blob* Workspace::CreateLocalBlob(const string& name) {
if (blob_map_.count(name)) { auto p = blob_map_.emplace(name, nullptr);
if (!p.second) {
VLOG(1) << "Blob " << name << " already exists. Skipping."; VLOG(1) << "Blob " << name << " already exists. Skipping.";
} else { } else {
VLOG(1) << "Creating blob " << name; VLOG(1) << "Creating blob " << name;
blob_map_[name] = unique_ptr<Blob>(new Blob()); p.first->second = std::make_unique<Blob>();
} }
return GetBlob(name); return p.first->second.get();
} }
Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) { Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
@ -158,15 +159,28 @@ bool Workspace::RemoveBlob(const string& name) {
} }
const Blob* Workspace::GetBlob(const string& name) const { const Blob* Workspace::GetBlob(const string& name) const {
if (blob_map_.count(name)) { {
return blob_map_.at(name).get(); auto it = blob_map_.find(name);
} else if (forwarded_blobs_.count(name)) { if (it != blob_map_.end()) {
const auto parent_ws = forwarded_blobs_.at(name).first; return it->second.get();
const auto& parent_name = forwarded_blobs_.at(name).second;
return parent_ws->GetBlob(parent_name);
} else if (shared_ && shared_->HasBlob(name)) {
return shared_->GetBlob(name);
} }
}
{
auto it = forwarded_blobs_.find(name);
if (it != forwarded_blobs_.end()) {
const auto* parent_ws = it->second.first;
const auto& parent_name = it->second.second;
return parent_ws->GetBlob(parent_name);
}
}
if (shared_) {
if (auto blob = shared_->GetBlob(name)) {
return blob;
}
}
LOG(WARNING) << "Blob " << name << " not in the workspace."; LOG(WARNING) << "Blob " << name << " not in the workspace.";
// TODO(Yangqing): do we want to always print out the list of blobs here? // TODO(Yangqing): do we want to always print out the list of blobs here?
// LOG(WARNING) << "Current blobs:"; // LOG(WARNING) << "Current blobs:";
@ -249,25 +263,25 @@ NetBase* Workspace::CreateNet(
} }
NetBase* Workspace::GetNet(const string& name) { NetBase* Workspace::GetNet(const string& name) {
if (!net_map_.count(name)) { auto it = net_map_.find(name);
return nullptr; if (it != net_map_.end()) {
} else { return it->second.get();
return net_map_[name].get();
} }
return nullptr;
} }
void Workspace::DeleteNet(const string& name) { void Workspace::DeleteNet(const string& name) {
if (net_map_.count(name)) {
net_map_.erase(name); net_map_.erase(name);
}
} }
bool Workspace::RunNet(const string& name) { bool Workspace::RunNet(const string& name) {
if (!net_map_.count(name)) { auto it = net_map_.find(name);
if (it == net_map_.end()) {
LOG(ERROR) << "Network " << name << " does not exist yet."; LOG(ERROR) << "Network " << name << " does not exist yet.";
return false; return false;
} }
return net_map_[name]->Run(); return it->second->Run();
} }
bool Workspace::RunOperatorOnce(const OperatorDef& op_def) { bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {

View File

@ -8,6 +8,7 @@
#include <cstddef> #include <cstddef>
#include <mutex> #include <mutex>
#include <typeinfo> #include <typeinfo>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -47,8 +48,6 @@ struct TORCH_API StopOnSignal {
class TORCH_API Workspace { class TORCH_API Workspace {
public: public:
typedef std::function<bool(int)> ShouldContinue; typedef std::function<bool(int)> ShouldContinue;
typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
/** /**
* Initializes an empty workspace. * Initializes an empty workspace.
*/ */
@ -136,10 +135,11 @@ class TORCH_API Workspace {
template <class Context> template <class Context>
void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) { void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
for (const auto& blob : blobs) { for (const auto& blob : blobs) {
if (!forwarded_blobs_.count(blob)) { auto it = forwarded_blobs_.find(blob);
if (it == forwarded_blobs_.end()) {
continue; continue;
} }
const auto& ws_blob = forwarded_blobs_[blob]; const auto& ws_blob = it->second;
const auto* parent_ws = ws_blob.first; const auto* parent_ws = ws_blob.first;
auto* from_blob = parent_ws->GetBlob(ws_blob.second); auto* from_blob = parent_ws->GetBlob(ws_blob.second);
CAFFE_ENFORCE(from_blob); CAFFE_ENFORCE(from_blob);
@ -181,13 +181,19 @@ class TORCH_API Workspace {
// Then, check the forwarding map, then the parent workspace // Then, check the forwarding map, then the parent workspace
if (blob_map_.count(name)) { if (blob_map_.count(name)) {
return true; return true;
} else if (forwarded_blobs_.count(name)) { }
const auto parent_ws = forwarded_blobs_.at(name).first;
const auto& parent_name = forwarded_blobs_.at(name).second; auto it = forwarded_blobs_.find(name);
if (it != forwarded_blobs_.end()) {
const auto parent_ws = it->second.first;
const auto& parent_name = it->second.second;
return parent_ws->HasBlob(parent_name); return parent_ws->HasBlob(parent_name);
} else if (shared_) { }
if (shared_) {
return shared_->HasBlob(name); return shared_->HasBlob(name);
} }
return false; return false;
} }
@ -318,7 +324,7 @@ class TORCH_API Workspace {
static std::shared_ptr<Bookkeeper> bookkeeper(); static std::shared_ptr<Bookkeeper> bookkeeper();
BlobMap blob_map_; std::unordered_map<string, unique_ptr<Blob>> blob_map_;
const string root_folder_; const string root_folder_;
const Workspace* shared_; const Workspace* shared_;
std::unordered_map<string, std::pair<const Workspace*, string>> std::unordered_map<string, std::pair<const Workspace*, string>>
@ -326,7 +332,7 @@ class TORCH_API Workspace {
std::unique_ptr<ThreadPool> thread_pool_; std::unique_ptr<ThreadPool> thread_pool_;
std::mutex thread_pool_creation_mutex_; std::mutex thread_pool_creation_mutex_;
std::shared_ptr<Bookkeeper> bookkeeper_; std::shared_ptr<Bookkeeper> bookkeeper_;
NetMap net_map_; std::unordered_map<string, unique_ptr<NetBase>> net_map_;
C10_DISABLE_COPY_AND_ASSIGN(Workspace); C10_DISABLE_COPY_AND_ASSIGN(Workspace);
}; };