mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Init threadpool with user defined num_threads before default (#136793)
Fixes #134714 (or attempts to, idk how to test yet) For posterity, how one can test: 1. make sure you have USE_PTHREADPOOL=1 or pull a packaged binary 2. run gdb --args python, with `r` to enter, `Ctrl-C` to pause, and `c` to get back into Python 3. import torch 4. torch.set_num_threads(1), make sure this does not trigger any additional threads getting created. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136793 Approved by: https://github.com/albanD
This commit is contained in:
parent
bc21689136
commit
adbcaee950
|
|
@ -209,8 +209,8 @@ void init_num_threads() {
|
|||
}
|
||||
|
||||
void set_num_threads(int nthreads) {
|
||||
#ifndef C10_MOBILE
|
||||
TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
|
||||
#ifndef C10_MOBILE
|
||||
int no_value = NOT_SET;
|
||||
if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
|
||||
// num_intraop_threads either stores a positive integer or CONSUMED,
|
||||
|
|
@ -229,9 +229,8 @@ void set_num_threads(int nthreads) {
|
|||
}
|
||||
}
|
||||
#else
|
||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
|
||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
|
||||
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
||||
pool->set_thread_count(nthreads);
|
||||
#endif // C10_MOBILE
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
|
|||
#endif
|
||||
#ifdef USE_PTHREADPOOL
|
||||
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
|
||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
|
||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
|
||||
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
||||
pool->set_thread_count(nthreads);
|
||||
#endif
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
at::native::mkldnn::clear_computation_cache();
|
||||
|
|
|
|||
|
|
@ -82,12 +82,9 @@ void PThreadPool::run(
|
|||
0u);
|
||||
}
|
||||
|
||||
// Forward declaration
|
||||
size_t getDefaultNumThreads();
|
||||
|
||||
PThreadPool* pthreadpool() {
|
||||
PThreadPool* pthreadpool(size_t thread_count) {
|
||||
static auto threadpool =
|
||||
std::make_unique<PThreadPool>(getDefaultNumThreads());
|
||||
std::make_unique<PThreadPool>(thread_count);
|
||||
#if !(defined(WIN32))
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, []() {
|
||||
|
|
@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
|
|||
return threadpool.get();
|
||||
}
|
||||
|
||||
// Forward declaration
|
||||
size_t getDefaultNumThreads();
|
||||
|
||||
PThreadPool* pthreadpool() {
|
||||
return pthreadpool(getDefaultNumThreads());
|
||||
}
|
||||
|
||||
pthreadpool_t pthreadpool_() {
|
||||
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class PThreadPool final {
|
|||
|
||||
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
|
||||
PThreadPool* pthreadpool();
|
||||
PThreadPool* pthreadpool(size_t thread_count);
|
||||
|
||||
// Exposes the underlying implementation of PThreadPool.
|
||||
// Only for use in external libraries so as to unify threading across
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user