Init threadpool with user defined num_threads before default (#136793)

Fixes #134714 (or at least attempts to; I do not yet know how to test it)

For posterity, here is how one can test this:
1. Make sure you build with USE_PTHREADPOOL=1, or pull a packaged binary.
2. Run `gdb --args python`; use `r` to start Python, `Ctrl-C` to pause into gdb, and `c` to continue back into Python.
3. Run `import torch`.
4. Call `torch.set_num_threads(1)` and make sure this does not trigger the creation of any additional threads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136793
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu 2024-09-27 10:25:43 -07:00 committed by PyTorch MergeBot
parent bc21689136
commit adbcaee950
4 changed files with 13 additions and 10 deletions

View File

@ -209,8 +209,8 @@ void init_num_threads() {
}
void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
#ifndef C10_MOBILE
int no_value = NOT_SET;
if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
// num_intraop_threads either stores a positive integer or CONSUMED,
@ -229,9 +229,8 @@ void set_num_threads(int nthreads) {
}
}
#else
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}

View File

@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
#endif
#ifdef USE_PTHREADPOOL
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif
#if AT_MKLDNN_ENABLED()
at::native::mkldnn::clear_computation_cache();

View File

@ -82,12 +82,9 @@ void PThreadPool::run(
0u);
}
// Forward declaration
size_t getDefaultNumThreads();
PThreadPool* pthreadpool() {
PThreadPool* pthreadpool(size_t thread_count) {
static auto threadpool =
std::make_unique<PThreadPool>(getDefaultNumThreads());
std::make_unique<PThreadPool>(thread_count);
#if !(defined(WIN32))
static std::once_flag flag;
std::call_once(flag, []() {
@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
return threadpool.get();
}
// Forward declaration
size_t getDefaultNumThreads();
// Zero-argument convenience overload: queries getDefaultNumThreads() for the
// thread count and forwards it to pthreadpool(size_t), returning whatever
// pool pointer that overload yields.
PThreadPool* pthreadpool() {
  const auto default_thread_count = getDefaultNumThreads();
  return pthreadpool(default_thread_count);
}
pthreadpool_t pthreadpool_() {
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
return nullptr;

View File

@ -42,6 +42,7 @@ class PThreadPool final {
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
PThreadPool* pthreadpool();
PThreadPool* pthreadpool(size_t thread_count);
// Exposes the underlying implementation of PThreadPool.
// Only for use in external libraries so as to unify threading across