Open registration for c10 thread pool (#17788)

Summary:
1. Move ATen threadpool & open registration mechanism to C10
2. Move the `global_work_queue` to use this open registration mechanism, to allow users to substitute in their own
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17788

Reviewed By: zdevito

Differential Revision: D14379707

Pulled By: jamesr66a

fbshipit-source-id: 949662d0024875abf09907d97db927f160c54d45
This commit is contained in:
James Reed 2019-03-08 15:33:34 -08:00 committed by Facebook Github Bot
parent 0955592243
commit 1d26a3ae7e
17 changed files with 114 additions and 92 deletions

View File

@ -1,5 +1,4 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/thread_pool.h>
#include <c10/core/thread_pool.h>
namespace c10 {
@ -125,9 +124,32 @@ void setNumThreads(size_t v) {
}
}
ThreadPool& global_work_queue() {
static ThreadPool thread_pool(num_threads.exchange(-1));
return thread_pool;
TaskThreadPoolBase& global_work_queue() {
static std::shared_ptr<TaskThreadPoolBase> pool =
ThreadPoolRegistry()->Create("C10", 0, num_threads.exchange(-1), false);
return *pool;
}
C10_DEFINE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
namespace {
std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
int device_id,
int pool_size,
bool create_new) {
static std::shared_ptr<TaskThreadPoolBase> pool =
std::make_shared<ThreadPool>(pool_size);
return pool;
}
} // namespace
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
} // namespace c10

View File

@ -9,6 +9,8 @@
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/numa.h>
#include <c10/util/thread_name.h>
namespace c10 {
@ -17,7 +19,7 @@ struct Future;
} // namespace ivalue
// TODO: move this to C10 and make it C10_API
class CAFFE2_API TaskThreadPoolBase {
class C10_API TaskThreadPoolBase {
public:
virtual void run(const std::function<void()>& func) = 0;
@ -36,7 +38,7 @@ class CAFFE2_API TaskThreadPoolBase {
virtual ~TaskThreadPoolBase() noexcept {}
};
class CAFFE2_API ThreadPool : public c10::TaskThreadPoolBase {
class C10_API ThreadPool : public c10::TaskThreadPoolBase {
protected:
struct task_element_t {
bool run_with_id;
@ -100,8 +102,29 @@ class CAFFE2_API ThreadPool : public c10::TaskThreadPoolBase {
void main_loop(std::size_t index);
};
CAFFE2_API void setNumThreads(size_t v);
C10_API void setNumThreads(size_t v);
CAFFE2_API ThreadPool& global_work_queue();
C10_API TaskThreadPoolBase& global_work_queue();
class C10_API TaskThreadPool : public c10::ThreadPool {
public:
explicit TaskThreadPool(
std::size_t pool_size,
int numa_node_id = -1)
: ThreadPool(pool_size, numa_node_id) {}
// TODO move this to ATen/core/thread_pool.h
void init_thread() override {
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id_);
}
};
C10_DECLARE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
} // namespace c10

View File

@ -30,6 +30,20 @@
// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static
// libraries.
// For build systems that do not directly depend on CMake and directly build
// from the source directory (such as Buck), one may not have a cmake_macros.h
// file at all. In this case, the build system is responsible for providing
// correct macro definitions corresponding to the cmake_macros.h.in file.
//
// In such scenarios, one should define the macro
// C10_USING_CUSTOM_GENERATED_MACROS
// to inform this header that it does not need to include the cmake_macros.h
// file.
#ifndef C10_USING_CUSTOM_GENERATED_MACROS
#include "c10/macros/cmake_macros.h"
#endif // C10_USING_CUSTOM_GENERATED_MACROS
#ifdef _WIN32
#if defined(C10_BUILD_SHARED_LIBS)
#define C10_EXPORT __declspec(dllexport)

View File

@ -1,19 +1,19 @@
#include "caffe2/utils/thread_name.h"
#include "c10/util/thread_name.h"
#include <algorithm>
#if defined(__GLIBC__) && !defined(__APPLE__) && !defined(__ANDROID__)
#define CAFFE2_HAS_PTHREAD_SETNAME_NP
#define C10_HAS_PTHREAD_SETNAME_NP
#endif
#ifdef CAFFE2_HAS_PTHREAD_SETNAME_NP
#ifdef C10_HAS_PTHREAD_SETNAME_NP
#include <pthread.h>
#endif
namespace caffe2 {
namespace c10 {
void setThreadName(std::string name) {
#ifdef CAFFE2_HAS_PTHREAD_SETNAME_NP
#ifdef C10_HAS_PTHREAD_SETNAME_NP
constexpr size_t kMaxThreadName = 15;
name.resize(std::min(name.size(), kMaxThreadName));
@ -21,4 +21,4 @@ void setThreadName(std::string name) {
#endif
}
} // namespace caffe2
} // namespace c10

11
c10/util/thread_name.h Normal file
View File

@ -0,0 +1,11 @@
#pragma once
#include <string>
#include "c10/macros/Export.h"
namespace c10 {
C10_API void setThreadName(std::string name);
} // namespace c10

View File

@ -9,6 +9,7 @@
#include <unordered_map>
#include <vector>
#include "c10/core/thread_pool.h"
#include "c10/util/Registry.h"
#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
@ -19,7 +20,6 @@
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/simple_queue.h"
#include "caffe2/utils/thread_pool.h"
C10_DECLARE_string(caffe2_override_executor);

View File

@ -151,7 +151,7 @@ TaskThreadPoolBase* AsyncNetBase::poolGetter(
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
auto pool = pools[device_id][pool_size];
if (!pool) {
pool = ThreadPoolRegistry()->Create(
pool = c10::ThreadPoolRegistry()->Create(
DeviceTypeName(device_type),
device_id,
pool_size,
@ -478,26 +478,6 @@ AsyncNetBase::~AsyncNetBase() {
}
}
C10_DEFINE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
C10_REGISTER_CREATOR(
ThreadPoolRegistry,
CPU,
GetAsyncNetThreadPool<TaskThreadPool, PROTO_CPU>);
C10_REGISTER_CREATOR(
ThreadPoolRegistry,
CUDA,
GetAsyncNetThreadPool<TaskThreadPool, PROTO_CUDA>);
C10_REGISTER_CREATOR(
ThreadPoolRegistry,
HIP,
GetAsyncNetThreadPool<TaskThreadPool, PROTO_HIP>);
ExecutionOptions::ExecutionOptions(
const std::shared_ptr<const NetDef>& net_def) {
static const std::string kDag = "dag";
@ -558,3 +538,20 @@ ExecutionOptions::ExecutionOptions(
}
} // namespace caffe2
namespace c10 {
C10_REGISTER_CREATOR(
ThreadPoolRegistry,
CPU,
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CPU>);
C10_REGISTER_CREATOR(
ThreadPoolRegistry,
CUDA,
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CUDA>);
C10_REGISTER_CREATOR(
ThreadPoolRegistry,
HIP,
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_HIP>);
} // namespace c10

View File

@ -1,6 +1,7 @@
#ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
#define CAFFE2_CORE_NET_ASYNC_BASE_H_
#include "c10/core/thread_pool.h"
#include "c10/util/Registry.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
@ -12,7 +13,6 @@
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/prof_dag.pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/thread_pool.h"
C10_DECLARE_int(caffe2_streams_per_gpu);
C10_DECLARE_int(caffe2_net_async_max_gpus);
@ -167,13 +167,6 @@ class CAFFE2_API AsyncNetBase : public NetBase {
friend class tracing::Tracer;
};
C10_DECLARE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
class AsyncNetExecutorHelper : public ExecutorHelper {
public:
explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}

View File

@ -132,7 +132,7 @@ TaskThreadPoolBase* ParallelNet::poolGetter(
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
auto pool = pools[device_id][pool_size];
if (!pool) {
pool = ThreadPoolRegistry()->Create(
pool = c10::ThreadPoolRegistry()->Create(
DeviceTypeName(device_type),
device_id,
pool_size,

View File

@ -7,14 +7,14 @@
#include <iostream>
#include <algorithm>
#include "c10/core/thread_pool.h"
#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/image/transform_gpu.h"
#include "caffe2/operators/prefetch_op.h"
#include "caffe2/proto/caffe2_legacy.pb.h"
#include "caffe2/utils/cast.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/thread_pool.h"
#include "caffe2/operators/prefetch_op.h"
#include "caffe2/image/transform_gpu.h"
namespace caffe2 {

View File

@ -13,7 +13,6 @@ list(APPEND Caffe2_CPU_SRCS
utils/signal_handler.cc
utils/smart_tensor_printer.cc
utils/string_utils.cc
utils/thread_name.cc
utils/threadpool/ThreadPool.cc)
# ---[ threadpool/pthreadpool* is a local modification of the NNPACK

View File

@ -1,11 +0,0 @@
#pragma once
#include <string>
#include "caffe2/core/common.h"
namespace caffe2 {
CAFFE2_API void setThreadName(std::string name);
} // namespace caffe2

View File

@ -1,26 +0,0 @@
#ifndef CAFFE2_UTILS_THREAD_POOL_H_
#define CAFFE2_UTILS_THREAD_POOL_H_
#include "ATen/core/thread_pool.h"
#include "caffe2/core/numa.h"
#include "caffe2/utils/thread_name.h"
namespace caffe2 {
class CAFFE2_API TaskThreadPool : public c10::ThreadPool {
public:
explicit TaskThreadPool(
std::size_t pool_size,
int numa_node_id = -1)
: ThreadPool(pool_size, numa_node_id) {}
// TODO move this to ATen/core/thread_pool.h
void init_thread() override {
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id_);
}
};
} // namespace caffe2
#endif // CAFFE2_UTILS_THREAD_POOL_H_

View File

@ -3,9 +3,9 @@
#include <atomic>
#include <condition_variable>
#include <thread>
#include "c10/util/thread_name.h"
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/utils/thread_name.h"
#if defined(_MSC_VER)
#include <intrin.h>
@ -263,7 +263,7 @@ class alignas(kGEMMLOWPCacheLineSize) Worker {
// Thread entry point.
void ThreadFunc() {
setThreadName("CaffeWorkersPool");
c10::setThreadName("CaffeWorkersPool");
ChangeState(State::Ready);
// Thread main loop

View File

@ -6,11 +6,11 @@
#include <random>
#include <string>
#include <c10/core/thread_pool.h>
#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/utils/thread_pool.h>
#include <caffe2/video/video_io.h>
namespace caffe2 {

View File

@ -13,7 +13,7 @@
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <ATen/core/thread_pool.h>
#include <c10/core/thread_pool.h>
#include <exception>
#include <iostream>

View File

@ -14,7 +14,7 @@
#include <ATen/ExpandUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/thread_pool.h>
#include <c10/core/thread_pool.h>
#include <c10/util/SmallVector.h>
#include <algorithm>