pytorch/caffe2/predictor/ThreadLocalPtr.h
Alexander Sidorov d522b3ca58 BlackBoxPredictor OSS part N: ThreadLocalPtr, InferenceGraph (#23257)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23257

Overall context: open-source BlackBoxPredictor as the entry
point for inference in Caffe2 (thread safe abstraction for Caffe2
inference). This should be used in ThroughputBenchmark for the purpose
of framework comparison.
This specific diff:
There should be no harm in moving transformation code to
OSS. On the advantages side we will be able to compare production
Caffe2 setup with PyTorch in the most fair way via
ThroughputBenchmark. This approach avoids any complicated
transformation registries. Building those properly would be a significant
engineering effort as well as a production risk. In the past we had SEVs
related to transforms being turned off due to various refactors. Given
that we don't plan to make any other significant investments in
transformation logic except existing ones (like TVM and Glow), and
those also relate to open-source technologies, I came to the
conclusion of moving the whole thing to OSS.

Reviewed By: zrphercule

Differential Revision: D16428124

fbshipit-source-id: b35deada5c015cd97b91ae12a7ea4aac53bd14b8
2019-07-24 14:35:30 -07:00

159 lines
4.0 KiB
C++

#pragma once

#include <memory> // std::shared_ptr, std::unique_ptr
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector> // std::vector

#include "caffe2/core/logging.h"

namespace caffe2 {
/**
 * thread_local pointer in C++ is a per thread pointer. However, sometimes
 * we want to have a thread local state that is per thread and also per
 * instance. e.g. we have the following class:
 *   class A {
 *     ThreadLocalPtr<int> x;
 *   };
 * We would like to have a copy of x per thread and also per instance of
 * class A. This can be applied to storing per instance thread local state of
 * some class, when we could have multiple instances of the class in the same
 * thread. We implemented a subset of functions in folly::ThreadLocalPtr
 * that's enough to support BlackBoxPredictor.
 */
class ThreadLocalPtrImpl;
class ThreadLocalHelper;
/**
 * Map of object pointer to instance in each thread
 * to achieve per thread(using thread_local) per object(using the map)
 * thread local pointer
 */
typedef std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>
    UnsafeThreadLocalMap;

ThreadLocalHelper* getThreadLocalHelper();

typedef std::vector<ThreadLocalHelper*> UnsafeAllThreadLocalHelperVector;
/**
 * A thread safe vector of all ThreadLocalHelper, this will be used
 * to encapsulate the locking in the APIs for the changes to the global
 * AllThreadLocalHelperVector instance.
 */
class AllThreadLocalHelperVector {
 public:
  AllThreadLocalHelperVector() {}

  // Add a new ThreadLocalHelper to the vector
  void push_back(ThreadLocalHelper* helper);

  // Erase a ThreadLocalHelper from the vector
  void erase(ThreadLocalHelper* helper);

  // Erase object in all the helpers stored in vector
  // Called during destructor of a ThreadLocalPtrImpl
  void erase_tlp(ThreadLocalPtrImpl* ptr);

 private:
  UnsafeAllThreadLocalHelperVector vector_;
  std::mutex mutex_;
};
/**
 * ThreadLocalHelper is per thread
 */
class ThreadLocalHelper {
 public:
  ThreadLocalHelper();

  // When the thread dies, we want to clean up *this*
  // in AllThreadLocalHelperVector
  ~ThreadLocalHelper();

  // Insert a (object, ptr) pair into the thread local map
  void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr);

  // Get the ptr by object
  void* get(ThreadLocalPtrImpl* key);

  // Erase the ptr associated with the object in the map
  void erase(ThreadLocalPtrImpl* key);

 private:
  // mapping of object -> ptr in each thread
  UnsafeThreadLocalMap mapping_;
  std::mutex mutex_;
}; // ThreadLocalHelper
/**
 * ThreadLocalPtrImpl is per object
 */
class ThreadLocalPtrImpl {
 public:
  ThreadLocalPtrImpl() {}

  // Delete copy and move constructors and assignment operators
  ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete;
  ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete;
  ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete;
  ThreadLocalPtrImpl& operator=(ThreadLocalPtrImpl&&) = delete;

  // In the case when object dies first, we want to
  // clean up the states in all child threads
  ~ThreadLocalPtrImpl();

  // Returns the calling thread's pointer for this object
  template <typename T>
  T* get() {
    return static_cast<T*>(getThreadLocalHelper()->get(this));
  }

  // Replaces the calling thread's pointer for this object, taking ownership
  // of newPtr
  template <typename T>
  void reset(T* newPtr = nullptr) {
    VLOG(2) << "In Reset(" << newPtr << ")";
    auto* wrapper = getThreadLocalHelper();
    // Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread
    wrapper->erase(this);
    if (newPtr != nullptr) {
      std::shared_ptr<void> sharedPtr(newPtr);
      // Deletion of newPtr is handled by shared_ptr
      // as it implements type erasure
      wrapper->insert(this, std::move(sharedPtr));
    }
  }
}; // ThreadLocalPtrImpl
template <typename T>
class ThreadLocalPtr {
 public:
  auto* operator->() {
    return get();
  }

  auto& operator*() {
    return *get();
  }

  auto* get() {
    return impl_.get<T>();
  }

  auto* operator->() const {
    return get();
  }

  auto& operator*() const {
    return *get();
  }

  auto* get() const {
    return impl_.get<T>();
  }

  // Takes ownership of ptr for the calling thread
  void reset(std::unique_ptr<T> ptr = nullptr) {
    impl_.reset<T>(ptr.release());
  }

 private:
  // mutable so that the const accessors above can call the non-const
  // ThreadLocalPtrImpl::get<T>()
  mutable ThreadLocalPtrImpl impl_;
};
} // namespace caffe2
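
The per-thread, per-instance idea described in the header comment can be illustrated without any Caffe2 dependency. The following is only a minimal sketch under stated assumptions: `PerInstanceThreadLocal` and `demo` are hypothetical names that do not exist in Caffe2, and the sketch deliberately omits the cross-thread cleanup that `AllThreadLocalHelperVector::erase_tlp` performs when an object dies before its threads. It keeps one `thread_local` map keyed by the object's address and stores values as `std::shared_ptr<void>`, so each (thread, instance) pair gets its own slot.

```cpp
#include <cassert>
#include <memory>
#include <thread>
#include <unordered_map>

// Hypothetical, simplified stand-in for caffe2::ThreadLocalPtr<T>.
template <typename T>
class PerInstanceThreadLocal {
 public:
  PerInstanceThreadLocal() = default;
  PerInstanceThreadLocal(const PerInstanceThreadLocal&) = delete;
  PerInstanceThreadLocal& operator=(const PerInstanceThreadLocal&) = delete;

  // Only the destroying thread's slot is erased here. Slots created by other
  // threads stay alive until those threads exit (the real header handles
  // that case through AllThreadLocalHelperVector).
  ~PerInstanceThreadLocal() {
    slots().erase(this);
  }

  // Returns the calling thread's value for this instance, or nullptr.
  T* get() const {
    auto& m = slots();
    auto it = m.find(this);
    return it == m.end() ? nullptr : static_cast<T*>(it->second.get());
  }

  // Replaces the calling thread's value for this instance.
  void reset(std::unique_ptr<T> ptr = nullptr) {
    auto& m = slots();
    m.erase(this);
    if (ptr) {
      // shared_ptr<void> remembers how to delete a T (type erasure).
      std::shared_ptr<void> erased(ptr.release());
      m.emplace(this, std::move(erased));
    }
  }

 private:
  // One map per thread, keyed by instance address.
  static std::unordered_map<const void*, std::shared_ptr<void>>& slots() {
    thread_local std::unordered_map<const void*, std::shared_ptr<void>> m;
    return m;
  }
};

// Encodes what each thread observed into a single int for easy checking.
int demo() {
  PerInstanceThreadLocal<int> a;
  PerInstanceThreadLocal<int> b;
  a.reset(std::make_unique<int>(1));
  b.reset(std::make_unique<int>(2));

  bool worker_saw_empty = false;
  std::thread worker([&] {
    // A new thread starts with no value for `a`.
    worker_saw_empty = (a.get() == nullptr);
    // Setting a value here must not affect the main thread's slot.
    a.reset(std::make_unique<int>(10));
  });
  worker.join();

  return (worker_saw_empty ? 100 : 0) + (*a.get()) * 10 + *b.get();
}
```

In this sketch `demo()` returns 112: the worker thread saw an empty slot (100), while the main thread still reads 1 from `a` (10) and 2 from `b`, which shows that values are isolated both per thread and per instance.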