mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[caffe2] Avoid some double (and triple) lookups in workspace (#53319)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53319 Noticed these in profiles. Also switch to `unordered_map`. Test Plan: Unit tests. Reviewed By: swolchok Differential Revision: D26504408 fbshipit-source-id: 9e14d55909a4af019058b8c27c67ee2348cd02a9
This commit is contained in:
parent
35364c3641
commit
69bb0e0285
|
|
@ -84,7 +84,7 @@ vector<string> Workspace::Blobs() const {
|
|||
names.push_back(entry.first);
|
||||
}
|
||||
for (const auto& forwarded : forwarded_blobs_) {
|
||||
const auto parent_ws = forwarded.second.first;
|
||||
const auto* parent_ws = forwarded.second.first;
|
||||
const auto& parent_name = forwarded.second.second;
|
||||
if (parent_ws->HasBlob(parent_name)) {
|
||||
names.push_back(forwarded.first);
|
||||
|
|
@ -112,13 +112,14 @@ Blob* Workspace::CreateBlob(const string& name) {
|
|||
}
|
||||
|
||||
Blob* Workspace::CreateLocalBlob(const string& name) {
|
||||
if (blob_map_.count(name)) {
|
||||
auto p = blob_map_.emplace(name, nullptr);
|
||||
if (!p.second) {
|
||||
VLOG(1) << "Blob " << name << " already exists. Skipping.";
|
||||
} else {
|
||||
VLOG(1) << "Creating blob " << name;
|
||||
blob_map_[name] = unique_ptr<Blob>(new Blob());
|
||||
p.first->second = std::make_unique<Blob>();
|
||||
}
|
||||
return GetBlob(name);
|
||||
return p.first->second.get();
|
||||
}
|
||||
|
||||
Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
|
||||
|
|
@ -158,15 +159,28 @@ bool Workspace::RemoveBlob(const string& name) {
|
|||
}
|
||||
|
||||
const Blob* Workspace::GetBlob(const string& name) const {
|
||||
if (blob_map_.count(name)) {
|
||||
return blob_map_.at(name).get();
|
||||
} else if (forwarded_blobs_.count(name)) {
|
||||
const auto parent_ws = forwarded_blobs_.at(name).first;
|
||||
const auto& parent_name = forwarded_blobs_.at(name).second;
|
||||
return parent_ws->GetBlob(parent_name);
|
||||
} else if (shared_ && shared_->HasBlob(name)) {
|
||||
return shared_->GetBlob(name);
|
||||
{
|
||||
auto it = blob_map_.find(name);
|
||||
if (it != blob_map_.end()) {
|
||||
return it->second.get();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
auto it = forwarded_blobs_.find(name);
|
||||
if (it != forwarded_blobs_.end()) {
|
||||
const auto* parent_ws = it->second.first;
|
||||
const auto& parent_name = it->second.second;
|
||||
return parent_ws->GetBlob(parent_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (shared_) {
|
||||
if (auto blob = shared_->GetBlob(name)) {
|
||||
return blob;
|
||||
}
|
||||
}
|
||||
|
||||
LOG(WARNING) << "Blob " << name << " not in the workspace.";
|
||||
// TODO(Yangqing): do we want to always print out the list of blobs here?
|
||||
// LOG(WARNING) << "Current blobs:";
|
||||
|
|
@ -249,25 +263,25 @@ NetBase* Workspace::CreateNet(
|
|||
}
|
||||
|
||||
NetBase* Workspace::GetNet(const string& name) {
|
||||
if (!net_map_.count(name)) {
|
||||
return nullptr;
|
||||
} else {
|
||||
return net_map_[name].get();
|
||||
auto it = net_map_.find(name);
|
||||
if (it != net_map_.end()) {
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Workspace::DeleteNet(const string& name) {
|
||||
if (net_map_.count(name)) {
|
||||
net_map_.erase(name);
|
||||
}
|
||||
}
|
||||
|
||||
bool Workspace::RunNet(const string& name) {
|
||||
if (!net_map_.count(name)) {
|
||||
auto it = net_map_.find(name);
|
||||
if (it == net_map_.end()) {
|
||||
LOG(ERROR) << "Network " << name << " does not exist yet.";
|
||||
return false;
|
||||
}
|
||||
return net_map_[name]->Run();
|
||||
return it->second->Run();
|
||||
}
|
||||
|
||||
bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <cstddef>
|
||||
#include <mutex>
|
||||
#include <typeinfo>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -47,8 +48,6 @@ struct TORCH_API StopOnSignal {
|
|||
class TORCH_API Workspace {
|
||||
public:
|
||||
typedef std::function<bool(int)> ShouldContinue;
|
||||
typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
|
||||
typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
|
||||
/**
|
||||
* Initializes an empty workspace.
|
||||
*/
|
||||
|
|
@ -136,10 +135,11 @@ class TORCH_API Workspace {
|
|||
template <class Context>
|
||||
void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
|
||||
for (const auto& blob : blobs) {
|
||||
if (!forwarded_blobs_.count(blob)) {
|
||||
auto it = forwarded_blobs_.find(blob);
|
||||
if (it == forwarded_blobs_.end()) {
|
||||
continue;
|
||||
}
|
||||
const auto& ws_blob = forwarded_blobs_[blob];
|
||||
const auto& ws_blob = it->second;
|
||||
const auto* parent_ws = ws_blob.first;
|
||||
auto* from_blob = parent_ws->GetBlob(ws_blob.second);
|
||||
CAFFE_ENFORCE(from_blob);
|
||||
|
|
@ -181,13 +181,19 @@ class TORCH_API Workspace {
|
|||
// Then, check the forwarding map, then the parent workspace
|
||||
if (blob_map_.count(name)) {
|
||||
return true;
|
||||
} else if (forwarded_blobs_.count(name)) {
|
||||
const auto parent_ws = forwarded_blobs_.at(name).first;
|
||||
const auto& parent_name = forwarded_blobs_.at(name).second;
|
||||
}
|
||||
|
||||
auto it = forwarded_blobs_.find(name);
|
||||
if (it != forwarded_blobs_.end()) {
|
||||
const auto parent_ws = it->second.first;
|
||||
const auto& parent_name = it->second.second;
|
||||
return parent_ws->HasBlob(parent_name);
|
||||
} else if (shared_) {
|
||||
}
|
||||
|
||||
if (shared_) {
|
||||
return shared_->HasBlob(name);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -318,7 +324,7 @@ class TORCH_API Workspace {
|
|||
|
||||
static std::shared_ptr<Bookkeeper> bookkeeper();
|
||||
|
||||
BlobMap blob_map_;
|
||||
std::unordered_map<string, unique_ptr<Blob>> blob_map_;
|
||||
const string root_folder_;
|
||||
const Workspace* shared_;
|
||||
std::unordered_map<string, std::pair<const Workspace*, string>>
|
||||
|
|
@ -326,7 +332,7 @@ class TORCH_API Workspace {
|
|||
std::unique_ptr<ThreadPool> thread_pool_;
|
||||
std::mutex thread_pool_creation_mutex_;
|
||||
std::shared_ptr<Bookkeeper> bookkeeper_;
|
||||
NetMap net_map_;
|
||||
std::unordered_map<string, unique_ptr<NetBase>> net_map_;
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(Workspace);
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user