mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23257 Overall context: open-source BlackBoxPredictor as the entry point for inference in Caffe2 (a thread-safe abstraction for Caffe2 inference). This should be used in ThroughputBenchmark for the purpose of framework comparison. This specific diff: there should be no harm in moving transformation code to OSS. On the advantages side, we will be able to compare the production Caffe2 setup with PyTorch in the fairest way via ThroughputBenchmark. This approach avoids any complicated transformation registries. Building those properly would be a significant engineering effort as well as a production risk. In the past we had SEVs related to transforms being turned off due to various refactors. Given that we don't plan to make any other significant investments in transformation logic except existing ones (like TVM and Glow), and those also relate to open-source technologies, I came to the conclusion that the whole thing should be moved to OSS. Reviewed By: zrphercule Differential Revision: D16428124 fbshipit-source-id: b35deada5c015cd97b91ae12a7ea4aac53bd14b8
159 lines
4.0 KiB
C++
159 lines
4.0 KiB
C++
#pragma once
|
|
|
|
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/logging.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
/**
 * A C++ `thread_local` pointer is local to a thread only. Sometimes, however,
 * we want state that is local both to a thread and to an object instance.
 * For example, given the class:
 *
 *   class A {
 *     ThreadLocalPtr<int> x;
 *   };
 *
 * we would like a separate copy of x for every (thread, instance of A) pair.
 * This is useful for storing per-instance thread-local state of a class when
 * several instances of that class are used from the same thread.
 *
 * This implements the subset of folly::ThreadLocalPtr functionality that is
 * needed to support BlackBoxPredictor.
 */
|
|
|
|
class ThreadLocalHelper;
class ThreadLocalPtrImpl;

/**
 * Table kept by each thread that maps an owning ThreadLocalPtrImpl (the key
 * object) to the value that object stores on that thread. A per-thread table
 * keyed per object yields storage that is local to both a thread and an
 * object instance.
 *
 * "Unsafe": the container itself performs no synchronization.
 */
using UnsafeThreadLocalMap =
    std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>;

// Returns the ThreadLocalHelper of the calling thread (defined out of line;
// presumably backed by a thread_local — confirm in the .cc).
ThreadLocalHelper* getThreadLocalHelper();

// Unsynchronized list of ThreadLocalHelper pointers; AllThreadLocalHelperVector
// wraps one of these together with a mutex.
using UnsafeAllThreadLocalHelperVector = std::vector<ThreadLocalHelper*>;
|
|
|
|
/**
|
|
* A thread safe vector of all ThreadLocalHelper, this will be used
|
|
* to encapuslate the locking in the APIs for the changes to the global
|
|
* AllThreadLocalHelperVector instance.
|
|
*/
|
|
class AllThreadLocalHelperVector {
|
|
public:
|
|
AllThreadLocalHelperVector() {}
|
|
|
|
// Add a new ThreadLocalHelper to the vector
|
|
void push_back(ThreadLocalHelper* helper);
|
|
|
|
// Erase a ThreadLocalHelper to the vector
|
|
void erase(ThreadLocalHelper* helper);
|
|
|
|
// Erase object in all the helpers stored in vector
|
|
// Called during destructor of a ThreadLocalPtrImpl
|
|
void erase_tlp(ThreadLocalPtrImpl* ptr);
|
|
|
|
private:
|
|
UnsafeAllThreadLocalHelperVector vector_;
|
|
std::mutex mutex_;
|
|
};
|
|
|
|
/**
|
|
* ThreadLocalHelper is per thread
|
|
*/
|
|
class ThreadLocalHelper {
|
|
public:
|
|
ThreadLocalHelper();
|
|
|
|
// When the thread dies, we want to clean up *this*
|
|
// in AllThreadLocalHelperVector
|
|
~ThreadLocalHelper();
|
|
|
|
// Insert a (object, ptr) pair into the thread local map
|
|
void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr);
|
|
// Get the ptr by object
|
|
void* get(ThreadLocalPtrImpl* key);
|
|
// Erase the ptr associated with the object in the map
|
|
void erase(ThreadLocalPtrImpl* key);
|
|
|
|
private:
|
|
// mapping of object -> ptr in each thread
|
|
UnsafeThreadLocalMap mapping_;
|
|
std::mutex mutex_;
|
|
}; // ThreadLocalHelper
|
|
|
|
/** ThreadLocalPtrImpl is per object
|
|
*/
|
|
class ThreadLocalPtrImpl {
|
|
public:
|
|
ThreadLocalPtrImpl() {}
|
|
// Delete copy and move constructors
|
|
ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete;
|
|
ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete;
|
|
ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete;
|
|
ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&&) = delete;
|
|
|
|
// In the case when object dies first, we want to
|
|
// clean up the states in all child threads
|
|
~ThreadLocalPtrImpl();
|
|
|
|
template <typename T>
|
|
T* get() {
|
|
return static_cast<T*>(getThreadLocalHelper()->get(this));
|
|
}
|
|
|
|
template <typename T>
|
|
void reset(T* newPtr = nullptr) {
|
|
VLOG(2) << "In Reset(" << newPtr << ")";
|
|
auto* wrapper = getThreadLocalHelper();
|
|
// Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread
|
|
wrapper->erase(this);
|
|
if (newPtr != nullptr) {
|
|
std::shared_ptr<void> sharedPtr(newPtr);
|
|
// Deletion of newPtr is handled by shared_ptr
|
|
// as it implements type erasure
|
|
wrapper->insert(this, std::move(sharedPtr));
|
|
}
|
|
}
|
|
|
|
}; // ThreadLocalPtrImpl
|
|
|
|
template <typename T>
|
|
class ThreadLocalPtr {
|
|
public:
|
|
auto* operator-> () {
|
|
return get();
|
|
}
|
|
|
|
auto& operator*() {
|
|
return *get();
|
|
}
|
|
|
|
auto* get() {
|
|
return impl_.get<T>();
|
|
}
|
|
|
|
auto* operator-> () const {
|
|
return get();
|
|
}
|
|
|
|
auto& operator*() const {
|
|
return *get();
|
|
}
|
|
|
|
auto* get() const {
|
|
return impl_.get<T>();
|
|
}
|
|
|
|
void reset(unique_ptr<T> ptr = nullptr) {
|
|
impl_.reset<T>(ptr.release());
|
|
}
|
|
|
|
private:
|
|
ThreadLocalPtrImpl impl_;
|
|
};
|
|
|
|
} // namespace caffe2
|