mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Open registration for c10 thread pool (#17788)
Summary: 1. Move ATen threadpool & open registration mechanism to C10 2. Move the `global_work_queue` to use this open registration mechanism, to allow users to substitute in their own Pull Request resolved: https://github.com/pytorch/pytorch/pull/17788 Reviewed By: zdevito Differential Revision: D14379707 Pulled By: jamesr66a fbshipit-source-id: 949662d0024875abf09907d97db927f160c54d45
This commit is contained in:
parent
0955592243
commit
1d26a3ae7e
|
|
@ -1,5 +1,4 @@
|
|||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/thread_pool.h>
|
||||
#include <c10/core/thread_pool.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
|
|
@ -125,9 +124,32 @@ void setNumThreads(size_t v) {
|
|||
}
|
||||
}
|
||||
|
||||
ThreadPool& global_work_queue() {
|
||||
static ThreadPool thread_pool(num_threads.exchange(-1));
|
||||
return thread_pool;
|
||||
TaskThreadPoolBase& global_work_queue() {
|
||||
static std::shared_ptr<TaskThreadPoolBase> pool =
|
||||
ThreadPoolRegistry()->Create("C10", 0, num_threads.exchange(-1), false);
|
||||
return *pool;
|
||||
}
|
||||
|
||||
C10_DEFINE_SHARED_REGISTRY(
|
||||
ThreadPoolRegistry,
|
||||
TaskThreadPoolBase,
|
||||
int,
|
||||
int,
|
||||
bool);
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
|
||||
int device_id,
|
||||
int pool_size,
|
||||
bool create_new) {
|
||||
static std::shared_ptr<TaskThreadPoolBase> pool =
|
||||
std::make_shared<ThreadPool>(pool_size);
|
||||
return pool;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
|
||||
|
||||
} // namespace c10
|
||||
|
|
@ -9,6 +9,8 @@
|
|||
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10/util/numa.h>
|
||||
#include <c10/util/thread_name.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
|
|
@ -17,7 +19,7 @@ struct Future;
|
|||
} // namespace ivalue
|
||||
|
||||
// TODO: move this to C10 and make it C10_API
|
||||
class CAFFE2_API TaskThreadPoolBase {
|
||||
class C10_API TaskThreadPoolBase {
|
||||
public:
|
||||
virtual void run(const std::function<void()>& func) = 0;
|
||||
|
||||
|
|
@ -36,7 +38,7 @@ class CAFFE2_API TaskThreadPoolBase {
|
|||
virtual ~TaskThreadPoolBase() noexcept {}
|
||||
};
|
||||
|
||||
class CAFFE2_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
protected:
|
||||
struct task_element_t {
|
||||
bool run_with_id;
|
||||
|
|
@ -100,8 +102,29 @@ class CAFFE2_API ThreadPool : public c10::TaskThreadPoolBase {
|
|||
void main_loop(std::size_t index);
|
||||
};
|
||||
|
||||
CAFFE2_API void setNumThreads(size_t v);
|
||||
C10_API void setNumThreads(size_t v);
|
||||
|
||||
CAFFE2_API ThreadPool& global_work_queue();
|
||||
C10_API TaskThreadPoolBase& global_work_queue();
|
||||
|
||||
class C10_API TaskThreadPool : public c10::ThreadPool {
|
||||
public:
|
||||
explicit TaskThreadPool(
|
||||
std::size_t pool_size,
|
||||
int numa_node_id = -1)
|
||||
: ThreadPool(pool_size, numa_node_id) {}
|
||||
|
||||
// TODO move this to ATen/core/thread_pool.h
|
||||
void init_thread() override {
|
||||
setThreadName("CaffeTaskThread");
|
||||
NUMABind(numa_node_id_);
|
||||
}
|
||||
};
|
||||
|
||||
C10_DECLARE_SHARED_REGISTRY(
|
||||
ThreadPoolRegistry,
|
||||
TaskThreadPoolBase,
|
||||
int,
|
||||
int,
|
||||
bool);
|
||||
|
||||
} // namespace c10
|
||||
|
|
@ -30,6 +30,20 @@
|
|||
// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static
|
||||
// libraries.
|
||||
|
||||
// For build systems that do not directly depend on CMake and directly build
|
||||
// from the source directory (such as Buck), one may not have a cmake_macros.h
|
||||
// file at all. In this case, the build system is responsible for providing
|
||||
// correct macro definitions corresponding to the cmake_macros.h.in file.
|
||||
//
|
||||
// In such scenarios, one should define the macro
|
||||
// C10_USING_CUSTOM_GENERATED_MACROS
|
||||
// to inform this header that it does not need to include the cmake_macros.h
|
||||
// file.
|
||||
|
||||
#ifndef C10_USING_CUSTOM_GENERATED_MACROS
|
||||
#include "c10/macros/cmake_macros.h"
|
||||
#endif // C10_USING_CUSTOM_GENERATED_MACROS
|
||||
|
||||
#ifdef _WIN32
|
||||
#if defined(C10_BUILD_SHARED_LIBS)
|
||||
#define C10_EXPORT __declspec(dllexport)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,19 @@
|
|||
#include "caffe2/utils/thread_name.h"
|
||||
#include "c10/util/thread_name.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(__GLIBC__) && !defined(__APPLE__) && !defined(__ANDROID__)
|
||||
#define CAFFE2_HAS_PTHREAD_SETNAME_NP
|
||||
#define C10_HAS_PTHREAD_SETNAME_NP
|
||||
#endif
|
||||
|
||||
#ifdef CAFFE2_HAS_PTHREAD_SETNAME_NP
|
||||
#ifdef C10_HAS_PTHREAD_SETNAME_NP
|
||||
#include <pthread.h>
|
||||
#endif
|
||||
|
||||
namespace caffe2 {
|
||||
namespace c10 {
|
||||
|
||||
void setThreadName(std::string name) {
|
||||
#ifdef CAFFE2_HAS_PTHREAD_SETNAME_NP
|
||||
#ifdef C10_HAS_PTHREAD_SETNAME_NP
|
||||
constexpr size_t kMaxThreadName = 15;
|
||||
name.resize(std::min(name.size(), kMaxThreadName));
|
||||
|
||||
|
|
@ -21,4 +21,4 @@ void setThreadName(std::string name) {
|
|||
#endif
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
} // namespace c10
|
||||
11
c10/util/thread_name.h
Normal file
11
c10/util/thread_name.h
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "c10/macros/Export.h"
|
||||
|
||||
namespace c10 {
|
||||
|
||||
C10_API void setThreadName(std::string name);
|
||||
|
||||
} // namespace c10
|
||||
|
|
@ -9,6 +9,7 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "c10/core/thread_pool.h"
|
||||
#include "c10/util/Registry.h"
|
||||
#include "caffe2/core/blob.h"
|
||||
#include "caffe2/core/common.h"
|
||||
|
|
@ -19,7 +20,6 @@
|
|||
#include "caffe2/core/workspace.h"
|
||||
#include "caffe2/proto/caffe2_pb.h"
|
||||
#include "caffe2/utils/simple_queue.h"
|
||||
#include "caffe2/utils/thread_pool.h"
|
||||
|
||||
C10_DECLARE_string(caffe2_override_executor);
|
||||
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ TaskThreadPoolBase* AsyncNetBase::poolGetter(
|
|||
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
|
||||
auto pool = pools[device_id][pool_size];
|
||||
if (!pool) {
|
||||
pool = ThreadPoolRegistry()->Create(
|
||||
pool = c10::ThreadPoolRegistry()->Create(
|
||||
DeviceTypeName(device_type),
|
||||
device_id,
|
||||
pool_size,
|
||||
|
|
@ -478,26 +478,6 @@ AsyncNetBase::~AsyncNetBase() {
|
|||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_SHARED_REGISTRY(
|
||||
ThreadPoolRegistry,
|
||||
TaskThreadPoolBase,
|
||||
int,
|
||||
int,
|
||||
bool);
|
||||
|
||||
C10_REGISTER_CREATOR(
|
||||
ThreadPoolRegistry,
|
||||
CPU,
|
||||
GetAsyncNetThreadPool<TaskThreadPool, PROTO_CPU>);
|
||||
C10_REGISTER_CREATOR(
|
||||
ThreadPoolRegistry,
|
||||
CUDA,
|
||||
GetAsyncNetThreadPool<TaskThreadPool, PROTO_CUDA>);
|
||||
C10_REGISTER_CREATOR(
|
||||
ThreadPoolRegistry,
|
||||
HIP,
|
||||
GetAsyncNetThreadPool<TaskThreadPool, PROTO_HIP>);
|
||||
|
||||
ExecutionOptions::ExecutionOptions(
|
||||
const std::shared_ptr<const NetDef>& net_def) {
|
||||
static const std::string kDag = "dag";
|
||||
|
|
@ -558,3 +538,20 @@ ExecutionOptions::ExecutionOptions(
|
|||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
namespace c10 {
|
||||
|
||||
C10_REGISTER_CREATOR(
|
||||
ThreadPoolRegistry,
|
||||
CPU,
|
||||
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CPU>);
|
||||
C10_REGISTER_CREATOR(
|
||||
ThreadPoolRegistry,
|
||||
CUDA,
|
||||
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CUDA>);
|
||||
C10_REGISTER_CREATOR(
|
||||
ThreadPoolRegistry,
|
||||
HIP,
|
||||
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_HIP>);
|
||||
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
|
||||
#define CAFFE2_CORE_NET_ASYNC_BASE_H_
|
||||
|
||||
#include "c10/core/thread_pool.h"
|
||||
#include "c10/util/Registry.h"
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/net.h"
|
||||
|
|
@ -12,7 +13,6 @@
|
|||
#include "caffe2/proto/caffe2_pb.h"
|
||||
#include "caffe2/proto/prof_dag.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/utils/thread_pool.h"
|
||||
|
||||
C10_DECLARE_int(caffe2_streams_per_gpu);
|
||||
C10_DECLARE_int(caffe2_net_async_max_gpus);
|
||||
|
|
@ -167,13 +167,6 @@ class CAFFE2_API AsyncNetBase : public NetBase {
|
|||
friend class tracing::Tracer;
|
||||
};
|
||||
|
||||
C10_DECLARE_SHARED_REGISTRY(
|
||||
ThreadPoolRegistry,
|
||||
TaskThreadPoolBase,
|
||||
int,
|
||||
int,
|
||||
bool);
|
||||
|
||||
class AsyncNetExecutorHelper : public ExecutorHelper {
|
||||
public:
|
||||
explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ TaskThreadPoolBase* ParallelNet::poolGetter(
|
|||
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
|
||||
auto pool = pools[device_id][pool_size];
|
||||
if (!pool) {
|
||||
pool = ThreadPoolRegistry()->Create(
|
||||
pool = c10::ThreadPoolRegistry()->Create(
|
||||
DeviceTypeName(device_type),
|
||||
device_id,
|
||||
pool_size,
|
||||
|
|
|
|||
|
|
@ -7,14 +7,14 @@
|
|||
#include <iostream>
|
||||
#include <algorithm>
|
||||
|
||||
#include "c10/core/thread_pool.h"
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/image/transform_gpu.h"
|
||||
#include "caffe2/operators/prefetch_op.h"
|
||||
#include "caffe2/proto/caffe2_legacy.pb.h"
|
||||
#include "caffe2/utils/cast.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
#include "caffe2/utils/thread_pool.h"
|
||||
#include "caffe2/operators/prefetch_op.h"
|
||||
#include "caffe2/image/transform_gpu.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ list(APPEND Caffe2_CPU_SRCS
|
|||
utils/signal_handler.cc
|
||||
utils/smart_tensor_printer.cc
|
||||
utils/string_utils.cc
|
||||
utils/thread_name.cc
|
||||
utils/threadpool/ThreadPool.cc)
|
||||
|
||||
# ---[ threadpool/pthreadpool* is a local modification of the NNPACK
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
CAFFE2_API void setThreadName(std::string name);
|
||||
|
||||
} // namespace caffe2
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
#ifndef CAFFE2_UTILS_THREAD_POOL_H_
|
||||
#define CAFFE2_UTILS_THREAD_POOL_H_
|
||||
|
||||
#include "ATen/core/thread_pool.h"
|
||||
#include "caffe2/core/numa.h"
|
||||
#include "caffe2/utils/thread_name.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
class CAFFE2_API TaskThreadPool : public c10::ThreadPool {
|
||||
public:
|
||||
explicit TaskThreadPool(
|
||||
std::size_t pool_size,
|
||||
int numa_node_id = -1)
|
||||
: ThreadPool(pool_size, numa_node_id) {}
|
||||
|
||||
// TODO move this to ATen/core/thread_pool.h
|
||||
void init_thread() override {
|
||||
setThreadName("CaffeTaskThread");
|
||||
NUMABind(numa_node_id_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_UTILS_THREAD_POOL_H_
|
||||
|
|
@ -3,9 +3,9 @@
|
|||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <thread>
|
||||
#include "c10/util/thread_name.h"
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/utils/thread_name.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include <intrin.h>
|
||||
|
|
@ -263,7 +263,7 @@ class alignas(kGEMMLOWPCacheLineSize) Worker {
|
|||
|
||||
// Thread entry point.
|
||||
void ThreadFunc() {
|
||||
setThreadName("CaffeWorkersPool");
|
||||
c10::setThreadName("CaffeWorkersPool");
|
||||
ChangeState(State::Ready);
|
||||
|
||||
// Thread main loop
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <caffe2/core/db.h>
|
||||
#include <caffe2/core/logging.h>
|
||||
#include <caffe2/operators/prefetch_op.h>
|
||||
#include <caffe2/utils/math.h>
|
||||
#include <caffe2/utils/thread_pool.h>
|
||||
#include <caffe2/video/video_io.h>
|
||||
|
||||
namespace caffe2 {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/script/jit_exception.h>
|
||||
#include <ATen/core/thread_pool.h>
|
||||
#include <c10/core/thread_pool.h>
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/thread_pool.h>
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user