diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index a2e19926500..7136630322b 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -209,8 +209,8 @@ void init_num_threads() { } void set_num_threads(int nthreads) { -#ifndef C10_MOBILE TORCH_CHECK(nthreads > 0, "Expected positive number of threads"); +#ifndef C10_MOBILE int no_value = NOT_SET; if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) { // num_intraop_threads either stores a positive integer or CONSUMED, @@ -229,9 +229,8 @@ void set_num_threads(int nthreads) { } } #else - caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads); TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); - pool->set_thread_count(nthreads); #endif // C10_MOBILE } diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index 40257882ea2..1c128bfc3b2 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -61,9 +61,8 @@ void set_num_threads(int nthreads) { #endif #ifdef USE_PTHREADPOOL // because PyTorch uses caffe2::pthreadpool() in QNNPACK - caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads); TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); - pool->set_thread_count(nthreads); #endif #if AT_MKLDNN_ENABLED() at::native::mkldnn::clear_computation_cache(); diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.cc b/caffe2/utils/threadpool/pthreadpool-cpp.cc index e281fa2cb40..6766b13d2b8 100644 --- a/caffe2/utils/threadpool/pthreadpool-cpp.cc +++ b/caffe2/utils/threadpool/pthreadpool-cpp.cc @@ -82,12 +82,9 @@ void PThreadPool::run( 0u); } -// Forward declaration -size_t getDefaultNumThreads(); - -PThreadPool* pthreadpool() { +PThreadPool* pthreadpool(size_t thread_count) { static auto threadpool = - std::make_unique<PThreadPool>(getDefaultNumThreads()); + 
std::make_unique<PThreadPool>(thread_count); #if !(defined(WIN32)) static std::once_flag flag; std::call_once(flag, []() { @@ -105,6 +102,13 @@ PThreadPool* pthreadpool() { return threadpool.get(); } +// Forward declaration +size_t getDefaultNumThreads(); + +PThreadPool* pthreadpool() { + return pthreadpool(getDefaultNumThreads()); +} + pthreadpool_t pthreadpool_() { if (caffe2::_NoPThreadPoolGuard::is_enabled()) { return nullptr; diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.h b/caffe2/utils/threadpool/pthreadpool-cpp.h index 99acff4df02..f6fc5a2d824 100644 --- a/caffe2/utils/threadpool/pthreadpool-cpp.h +++ b/caffe2/utils/threadpool/pthreadpool-cpp.h @@ -42,6 +42,7 @@ class PThreadPool final { // Return a singleton instance of PThreadPool for ATen/TH multithreading. PThreadPool* pthreadpool(); +PThreadPool* pthreadpool(size_t thread_count); // Exposes the underlying implementation of PThreadPool. // Only for use in external libraries so as to unify threading across