mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[caffe2] Avoid some double (and triple) lookups in workspace (#53319)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53319 Noticed these in profiles. Also switch to `unordered_map`. Test Plan: Unit tests. Reviewed By: swolchok Differential Revision: D26504408 fbshipit-source-id: 9e14d55909a4af019058b8c27c67ee2348cd02a9
This commit is contained in:
parent
35364c3641
commit
69bb0e0285
|
|
@ -84,7 +84,7 @@ vector<string> Workspace::Blobs() const {
|
||||||
names.push_back(entry.first);
|
names.push_back(entry.first);
|
||||||
}
|
}
|
||||||
for (const auto& forwarded : forwarded_blobs_) {
|
for (const auto& forwarded : forwarded_blobs_) {
|
||||||
const auto parent_ws = forwarded.second.first;
|
const auto* parent_ws = forwarded.second.first;
|
||||||
const auto& parent_name = forwarded.second.second;
|
const auto& parent_name = forwarded.second.second;
|
||||||
if (parent_ws->HasBlob(parent_name)) {
|
if (parent_ws->HasBlob(parent_name)) {
|
||||||
names.push_back(forwarded.first);
|
names.push_back(forwarded.first);
|
||||||
|
|
@ -112,13 +112,14 @@ Blob* Workspace::CreateBlob(const string& name) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Blob* Workspace::CreateLocalBlob(const string& name) {
|
Blob* Workspace::CreateLocalBlob(const string& name) {
|
||||||
if (blob_map_.count(name)) {
|
auto p = blob_map_.emplace(name, nullptr);
|
||||||
|
if (!p.second) {
|
||||||
VLOG(1) << "Blob " << name << " already exists. Skipping.";
|
VLOG(1) << "Blob " << name << " already exists. Skipping.";
|
||||||
} else {
|
} else {
|
||||||
VLOG(1) << "Creating blob " << name;
|
VLOG(1) << "Creating blob " << name;
|
||||||
blob_map_[name] = unique_ptr<Blob>(new Blob());
|
p.first->second = std::make_unique<Blob>();
|
||||||
}
|
}
|
||||||
return GetBlob(name);
|
return p.first->second.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
|
Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
|
||||||
|
|
@ -158,15 +159,28 @@ bool Workspace::RemoveBlob(const string& name) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const Blob* Workspace::GetBlob(const string& name) const {
|
const Blob* Workspace::GetBlob(const string& name) const {
|
||||||
if (blob_map_.count(name)) {
|
{
|
||||||
return blob_map_.at(name).get();
|
auto it = blob_map_.find(name);
|
||||||
} else if (forwarded_blobs_.count(name)) {
|
if (it != blob_map_.end()) {
|
||||||
const auto parent_ws = forwarded_blobs_.at(name).first;
|
return it->second.get();
|
||||||
const auto& parent_name = forwarded_blobs_.at(name).second;
|
|
||||||
return parent_ws->GetBlob(parent_name);
|
|
||||||
} else if (shared_ && shared_->HasBlob(name)) {
|
|
||||||
return shared_->GetBlob(name);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto it = forwarded_blobs_.find(name);
|
||||||
|
if (it != forwarded_blobs_.end()) {
|
||||||
|
const auto* parent_ws = it->second.first;
|
||||||
|
const auto& parent_name = it->second.second;
|
||||||
|
return parent_ws->GetBlob(parent_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shared_) {
|
||||||
|
if (auto blob = shared_->GetBlob(name)) {
|
||||||
|
return blob;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LOG(WARNING) << "Blob " << name << " not in the workspace.";
|
LOG(WARNING) << "Blob " << name << " not in the workspace.";
|
||||||
// TODO(Yangqing): do we want to always print out the list of blobs here?
|
// TODO(Yangqing): do we want to always print out the list of blobs here?
|
||||||
// LOG(WARNING) << "Current blobs:";
|
// LOG(WARNING) << "Current blobs:";
|
||||||
|
|
@ -249,25 +263,25 @@ NetBase* Workspace::CreateNet(
|
||||||
}
|
}
|
||||||
|
|
||||||
NetBase* Workspace::GetNet(const string& name) {
|
NetBase* Workspace::GetNet(const string& name) {
|
||||||
if (!net_map_.count(name)) {
|
auto it = net_map_.find(name);
|
||||||
return nullptr;
|
if (it != net_map_.end()) {
|
||||||
} else {
|
return it->second.get();
|
||||||
return net_map_[name].get();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Workspace::DeleteNet(const string& name) {
|
void Workspace::DeleteNet(const string& name) {
|
||||||
if (net_map_.count(name)) {
|
|
||||||
net_map_.erase(name);
|
net_map_.erase(name);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Workspace::RunNet(const string& name) {
|
bool Workspace::RunNet(const string& name) {
|
||||||
if (!net_map_.count(name)) {
|
auto it = net_map_.find(name);
|
||||||
|
if (it == net_map_.end()) {
|
||||||
LOG(ERROR) << "Network " << name << " does not exist yet.";
|
LOG(ERROR) << "Network " << name << " does not exist yet.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return net_map_[name]->Run();
|
return it->second->Run();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {
|
bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <typeinfo>
|
#include <typeinfo>
|
||||||
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -47,8 +48,6 @@ struct TORCH_API StopOnSignal {
|
||||||
class TORCH_API Workspace {
|
class TORCH_API Workspace {
|
||||||
public:
|
public:
|
||||||
typedef std::function<bool(int)> ShouldContinue;
|
typedef std::function<bool(int)> ShouldContinue;
|
||||||
typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
|
|
||||||
typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
|
|
||||||
/**
|
/**
|
||||||
* Initializes an empty workspace.
|
* Initializes an empty workspace.
|
||||||
*/
|
*/
|
||||||
|
|
@ -136,10 +135,11 @@ class TORCH_API Workspace {
|
||||||
template <class Context>
|
template <class Context>
|
||||||
void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
|
void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
|
||||||
for (const auto& blob : blobs) {
|
for (const auto& blob : blobs) {
|
||||||
if (!forwarded_blobs_.count(blob)) {
|
auto it = forwarded_blobs_.find(blob);
|
||||||
|
if (it == forwarded_blobs_.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const auto& ws_blob = forwarded_blobs_[blob];
|
const auto& ws_blob = it->second;
|
||||||
const auto* parent_ws = ws_blob.first;
|
const auto* parent_ws = ws_blob.first;
|
||||||
auto* from_blob = parent_ws->GetBlob(ws_blob.second);
|
auto* from_blob = parent_ws->GetBlob(ws_blob.second);
|
||||||
CAFFE_ENFORCE(from_blob);
|
CAFFE_ENFORCE(from_blob);
|
||||||
|
|
@ -181,13 +181,19 @@ class TORCH_API Workspace {
|
||||||
// Then, check the forwarding map, then the parent workspace
|
// Then, check the forwarding map, then the parent workspace
|
||||||
if (blob_map_.count(name)) {
|
if (blob_map_.count(name)) {
|
||||||
return true;
|
return true;
|
||||||
} else if (forwarded_blobs_.count(name)) {
|
}
|
||||||
const auto parent_ws = forwarded_blobs_.at(name).first;
|
|
||||||
const auto& parent_name = forwarded_blobs_.at(name).second;
|
auto it = forwarded_blobs_.find(name);
|
||||||
|
if (it != forwarded_blobs_.end()) {
|
||||||
|
const auto parent_ws = it->second.first;
|
||||||
|
const auto& parent_name = it->second.second;
|
||||||
return parent_ws->HasBlob(parent_name);
|
return parent_ws->HasBlob(parent_name);
|
||||||
} else if (shared_) {
|
}
|
||||||
|
|
||||||
|
if (shared_) {
|
||||||
return shared_->HasBlob(name);
|
return shared_->HasBlob(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -318,7 +324,7 @@ class TORCH_API Workspace {
|
||||||
|
|
||||||
static std::shared_ptr<Bookkeeper> bookkeeper();
|
static std::shared_ptr<Bookkeeper> bookkeeper();
|
||||||
|
|
||||||
BlobMap blob_map_;
|
std::unordered_map<string, unique_ptr<Blob>> blob_map_;
|
||||||
const string root_folder_;
|
const string root_folder_;
|
||||||
const Workspace* shared_;
|
const Workspace* shared_;
|
||||||
std::unordered_map<string, std::pair<const Workspace*, string>>
|
std::unordered_map<string, std::pair<const Workspace*, string>>
|
||||||
|
|
@ -326,7 +332,7 @@ class TORCH_API Workspace {
|
||||||
std::unique_ptr<ThreadPool> thread_pool_;
|
std::unique_ptr<ThreadPool> thread_pool_;
|
||||||
std::mutex thread_pool_creation_mutex_;
|
std::mutex thread_pool_creation_mutex_;
|
||||||
std::shared_ptr<Bookkeeper> bookkeeper_;
|
std::shared_ptr<Bookkeeper> bookkeeper_;
|
||||||
NetMap net_map_;
|
std::unordered_map<string, unique_ptr<NetBase>> net_map_;
|
||||||
|
|
||||||
C10_DISABLE_COPY_AND_ASSIGN(Workspace);
|
C10_DISABLE_COPY_AND_ASSIGN(Workspace);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user