Init threadpool with user defined num_threads before default (#136793)

Fixes #134714 (or attempts to, idk how to test yet)

For posterity, how one can test:
1. make sure you have USE_PTHREADPOOL=1 or pull a packaged binary
2. run gdb --args python, with `r` to enter, `Ctrl-C` to pause, and `c` to get back into Python
3. import torch
4. torch.set_num_threads(1), make sure this does not trigger any additional threads getting created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136793
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu 2024-09-27 10:25:43 -07:00 committed by PyTorch MergeBot
parent bc21689136
commit adbcaee950
4 changed files with 13 additions and 10 deletions

View File

@ -209,8 +209,8 @@ void init_num_threads() {
} }
void set_num_threads(int nthreads) { void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
TORCH_CHECK(nthreads > 0, "Expected positive number of threads"); TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
#ifndef C10_MOBILE
int no_value = NOT_SET; int no_value = NOT_SET;
if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) { if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
// num_intraop_threads either stores a positive integer or CONSUMED, // num_intraop_threads either stores a positive integer or CONSUMED,
@ -229,9 +229,8 @@ void set_num_threads(int nthreads) {
} }
} }
#else #else
caffe2::PThreadPool* const pool = caffe2::pthreadpool(); caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif // C10_MOBILE #endif // C10_MOBILE
} }

View File

@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
#endif #endif
#ifdef USE_PTHREADPOOL #ifdef USE_PTHREADPOOL
// because PyTorch uses caffe2::pthreadpool() in QNNPACK // because PyTorch uses caffe2::pthreadpool() in QNNPACK
caffe2::PThreadPool* const pool = caffe2::pthreadpool(); caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif #endif
#if AT_MKLDNN_ENABLED() #if AT_MKLDNN_ENABLED()
at::native::mkldnn::clear_computation_cache(); at::native::mkldnn::clear_computation_cache();

View File

@ -82,12 +82,9 @@ void PThreadPool::run(
0u); 0u);
} }
// Forward declaration PThreadPool* pthreadpool(size_t thread_count) {
size_t getDefaultNumThreads();
PThreadPool* pthreadpool() {
static auto threadpool = static auto threadpool =
std::make_unique<PThreadPool>(getDefaultNumThreads()); std::make_unique<PThreadPool>(thread_count);
#if !(defined(WIN32)) #if !(defined(WIN32))
static std::once_flag flag; static std::once_flag flag;
std::call_once(flag, []() { std::call_once(flag, []() {
@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
return threadpool.get(); return threadpool.get();
} }
// Forward declaration
size_t getDefaultNumThreads();
PThreadPool* pthreadpool() {
return pthreadpool(getDefaultNumThreads());
}
pthreadpool_t pthreadpool_() { pthreadpool_t pthreadpool_() {
if (caffe2::_NoPThreadPoolGuard::is_enabled()) { if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
return nullptr; return nullptr;

View File

@ -42,6 +42,7 @@ class PThreadPool final {
// Return a singleton instance of PThreadPool for ATen/TH multithreading. // Return a singleton instance of PThreadPool for ATen/TH multithreading.
PThreadPool* pthreadpool(); PThreadPool* pthreadpool();
PThreadPool* pthreadpool(size_t thread_count);
// Exposes the underlying implementation of PThreadPool. // Exposes the underlying implementation of PThreadPool.
// Only for use in external libraries so as to unify threading across // Only for use in external libraries so as to unify threading across