[caffe2] Avoid some double (and triple) lookups in workspace (#53319)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53319

Noticed these in profiles.

Also switch to `unordered_map`.

Test Plan: Unit tests.

Reviewed By: swolchok

Differential Revision: D26504408

fbshipit-source-id: 9e14d55909a4af019058b8c27c67ee2348cd02a9
This commit is contained in:
Giuseppe Ottaviano 2021-03-04 22:54:26 -08:00 committed by Facebook GitHub Bot
parent 35364c3641
commit 69bb0e0285
2 changed files with 51 additions and 31 deletions

View File

@ -84,7 +84,7 @@ vector<string> Workspace::Blobs() const {
names.push_back(entry.first);
}
for (const auto& forwarded : forwarded_blobs_) {
const auto parent_ws = forwarded.second.first;
const auto* parent_ws = forwarded.second.first;
const auto& parent_name = forwarded.second.second;
if (parent_ws->HasBlob(parent_name)) {
names.push_back(forwarded.first);
@ -112,13 +112,14 @@ Blob* Workspace::CreateBlob(const string& name) {
}
Blob* Workspace::CreateLocalBlob(const string& name) {
if (blob_map_.count(name)) {
auto p = blob_map_.emplace(name, nullptr);
if (!p.second) {
VLOG(1) << "Blob " << name << " already exists. Skipping.";
} else {
VLOG(1) << "Creating blob " << name;
blob_map_[name] = unique_ptr<Blob>(new Blob());
p.first->second = std::make_unique<Blob>();
}
return GetBlob(name);
return p.first->second.get();
}
Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
@ -158,15 +159,28 @@ bool Workspace::RemoveBlob(const string& name) {
}
const Blob* Workspace::GetBlob(const string& name) const {
if (blob_map_.count(name)) {
return blob_map_.at(name).get();
} else if (forwarded_blobs_.count(name)) {
const auto parent_ws = forwarded_blobs_.at(name).first;
const auto& parent_name = forwarded_blobs_.at(name).second;
return parent_ws->GetBlob(parent_name);
} else if (shared_ && shared_->HasBlob(name)) {
return shared_->GetBlob(name);
{
auto it = blob_map_.find(name);
if (it != blob_map_.end()) {
return it->second.get();
}
}
{
auto it = forwarded_blobs_.find(name);
if (it != forwarded_blobs_.end()) {
const auto* parent_ws = it->second.first;
const auto& parent_name = it->second.second;
return parent_ws->GetBlob(parent_name);
}
}
if (shared_) {
if (auto blob = shared_->GetBlob(name)) {
return blob;
}
}
LOG(WARNING) << "Blob " << name << " not in the workspace.";
// TODO(Yangqing): do we want to always print out the list of blobs here?
// LOG(WARNING) << "Current blobs:";
@ -249,25 +263,25 @@ NetBase* Workspace::CreateNet(
}
NetBase* Workspace::GetNet(const string& name) {
if (!net_map_.count(name)) {
return nullptr;
} else {
return net_map_[name].get();
auto it = net_map_.find(name);
if (it != net_map_.end()) {
return it->second.get();
}
return nullptr;
}
void Workspace::DeleteNet(const string& name) {
if (net_map_.count(name)) {
net_map_.erase(name);
}
net_map_.erase(name);
}
bool Workspace::RunNet(const string& name) {
if (!net_map_.count(name)) {
auto it = net_map_.find(name);
if (it == net_map_.end()) {
LOG(ERROR) << "Network " << name << " does not exist yet.";
return false;
}
return net_map_[name]->Run();
return it->second->Run();
}
bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {

View File

@ -8,6 +8,7 @@
#include <cstddef>
#include <mutex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -47,8 +48,6 @@ struct TORCH_API StopOnSignal {
class TORCH_API Workspace {
public:
typedef std::function<bool(int)> ShouldContinue;
typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
/**
* Initializes an empty workspace.
*/
@ -136,10 +135,11 @@ class TORCH_API Workspace {
template <class Context>
void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
for (const auto& blob : blobs) {
if (!forwarded_blobs_.count(blob)) {
auto it = forwarded_blobs_.find(blob);
if (it == forwarded_blobs_.end()) {
continue;
}
const auto& ws_blob = forwarded_blobs_[blob];
const auto& ws_blob = it->second;
const auto* parent_ws = ws_blob.first;
auto* from_blob = parent_ws->GetBlob(ws_blob.second);
CAFFE_ENFORCE(from_blob);
@ -181,13 +181,19 @@ class TORCH_API Workspace {
// Then, check the forwarding map, then the parent workspace
if (blob_map_.count(name)) {
return true;
} else if (forwarded_blobs_.count(name)) {
const auto parent_ws = forwarded_blobs_.at(name).first;
const auto& parent_name = forwarded_blobs_.at(name).second;
}
auto it = forwarded_blobs_.find(name);
if (it != forwarded_blobs_.end()) {
const auto parent_ws = it->second.first;
const auto& parent_name = it->second.second;
return parent_ws->HasBlob(parent_name);
} else if (shared_) {
}
if (shared_) {
return shared_->HasBlob(name);
}
return false;
}
@ -318,7 +324,7 @@ class TORCH_API Workspace {
static std::shared_ptr<Bookkeeper> bookkeeper();
BlobMap blob_map_;
std::unordered_map<string, unique_ptr<Blob>> blob_map_;
const string root_folder_;
const Workspace* shared_;
std::unordered_map<string, std::pair<const Workspace*, string>>
@ -326,7 +332,7 @@ class TORCH_API Workspace {
std::unique_ptr<ThreadPool> thread_pool_;
std::mutex thread_pool_creation_mutex_;
std::shared_ptr<Bookkeeper> bookkeeper_;
NetMap net_map_;
std::unordered_map<string, unique_ptr<NetBase>> net_map_;
C10_DISABLE_COPY_AND_ASSIGN(Workspace);
};