mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11779 Pull Request resolved: https://github.com/pytorch/pytorch/pull/11731 This Predictor provides a thread-safe interface and also cleans up activations after each run, so in a multi-model setup the activation space doesn't explode. Reviewed By: highker Differential Revision: D9842374 fbshipit-source-id: bfe253ae5fc813e73a347c5147ff6b58d50781ea
80 lines
2.0 KiB
C++
80 lines
2.0 KiB
C++
#include "ThreadLocalPtr.h"
|
|
#include <algorithm>
|
|
|
|
namespace caffe2 {
|
|
|
|
// Meyers singleton
|
|
// Returns the single shared AllThreadLocalHelperVector, creating it on the
// first call. Every call returns the same pointer.
AllThreadLocalHelperVector* getAllThreadLocalHelperVector() {
  // The object is heap-allocated and deliberately never deleted, so there is
  // no destructor whose ordering relative to other statics could matter.
  static AllThreadLocalHelperVector* const registry =
      new AllThreadLocalHelperVector();
  return registry;
}
|
|
|
|
// Returns a pointer to the calling thread's ThreadLocalHelper. Each thread
// gets its own instance, constructed the first time that thread calls this
// function.
ThreadLocalHelper* getThreadLocalHelper() {
  // A block-scope thread_local variable has static-per-thread storage.
  thread_local ThreadLocalHelper perThreadHelper;
  return &perThreadHelper;
}
|
|
|
|
// AllThreadLocalHelperVector
|
|
|
|
// Appends `helper` to the list of registered helpers. The list is guarded by
// mutex_ for the duration of the call.
void AllThreadLocalHelperVector::push_back(ThreadLocalHelper* helper) {
  std::lock_guard<std::mutex> guard(mutex_);
  vector_.emplace_back(helper);
}
|
|
|
|
// Removes every occurrence of `helper` from the list of registered helpers.
// Does nothing if `helper` is not present. The list is guarded by mutex_ for
// the duration of the call.
void AllThreadLocalHelperVector::erase(ThreadLocalHelper* helper) {
  std::lock_guard<std::mutex> guard(mutex_);
  // Erase-remove: shift the surviving entries to the front, then drop the
  // leftover tail.
  auto tail = std::remove(vector_.begin(), vector_.end(), helper);
  vector_.erase(tail, vector_.end());
}
|
|
|
|
// Asks every registered helper to drop its entry for `ptr`, by calling
// erase(ptr) on each of them. mutex_ is held across the whole sweep, so the
// list cannot change while it is being walked.
void AllThreadLocalHelperVector::erase_tlp(ThreadLocalPtrImpl* ptr) {
  std::lock_guard<std::mutex> guard(mutex_);
  for (auto it = vector_.begin(); it != vector_.end(); ++it) {
    (*it)->erase(ptr);
  }
}
|
|
|
|
// ThreadLocalHelper
|
|
// Registers the newly constructed helper in the shared
// AllThreadLocalHelperVector.
ThreadLocalHelper::ThreadLocalHelper() {
  AllThreadLocalHelperVector* registry = getAllThreadLocalHelperVector();
  registry->push_back(this);
}
|
|
|
|
// Removes this helper from the shared AllThreadLocalHelperVector so it is no
// longer visited by erase_tlp().
ThreadLocalHelper::~ThreadLocalHelper() {
  AllThreadLocalHelperVector* registry = getAllThreadLocalHelperVector();
  registry->erase(this);
}
|
|
|
|
// Stores `ptr` in this helper's map under the key `tl_ptr`, taking over the
// shared_ptr passed in. If an entry for `tl_ptr` already exists it is left
// untouched. The map is guarded by mutex_ for the duration of the call.
void ThreadLocalHelper::insert(
    ThreadLocalPtrImpl* tl_ptr,
    std::shared_ptr<void> ptr) {
  std::lock_guard<std::mutex> guard(mutex_);
  mapping_.emplace(tl_ptr, std::move(ptr));
}
|
|
|
|
// Looks up the value stored for `key` in this helper's map.
// Returns the raw pointer held by the stored shared_ptr, or nullptr when the
// map has no entry for `key`.
void* ThreadLocalHelper::get(ThreadLocalPtrImpl* key) {
  // The lookup runs under the map's mutex because ~ThreadLocalPtrImpl()
  // removes entries from this map; without the lock, such a removal could
  // change the iterator returned by find() while it is still in use.
  std::lock_guard<std::mutex> guard(mutex_);
  const auto found = mapping_.find(key);
  return found == mapping_.end() ? nullptr : found->second.get();
}
|
|
|
|
// Drops the map entry stored under `key`, if there is one, releasing this
// helper's shared_ptr for it. The map is guarded by mutex_ for the duration
// of the call.
void ThreadLocalHelper::erase(ThreadLocalPtrImpl* key) {
  std::lock_guard<std::mutex> guard(mutex_);
  mapping_.erase(key);
}
|
|
|
|
// On destruction, has every registered ThreadLocalHelper erase the entry it
// keeps under this object's address.
ThreadLocalPtrImpl::~ThreadLocalPtrImpl() {
  AllThreadLocalHelperVector* registry = getAllThreadLocalHelperVector();
  registry->erase_tlp(this);
}
|
|
|
|
} // namespace caffe2
|