mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Init threadpool with user defined num_threads before default (#136793)
Fixes #134714 (or attempts to, idk how to test yet) For posterity, how one can test: 1. make sure you have USE_PTHREADPOOL=1 or pull a packaged binary 2. run gdb --args python, with `r` to enter, `Ctrl-C` to pause, and `c` to get back into Python 3. import torch 4. torch.set_num_threads(1), make sure this does not trigger any additional threads getting created. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136793 Approved by: https://github.com/albanD
This commit is contained in:
parent
bc21689136
commit
adbcaee950
|
|
@ -209,8 +209,8 @@ void init_num_threads() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_num_threads(int nthreads) {
|
void set_num_threads(int nthreads) {
|
||||||
#ifndef C10_MOBILE
|
|
||||||
TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
|
TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
|
||||||
|
#ifndef C10_MOBILE
|
||||||
int no_value = NOT_SET;
|
int no_value = NOT_SET;
|
||||||
if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
|
if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
|
||||||
// num_intraop_threads either stores a positive integer or CONSUMED,
|
// num_intraop_threads either stores a positive integer or CONSUMED,
|
||||||
|
|
@ -229,9 +229,8 @@ void set_num_threads(int nthreads) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
|
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
|
||||||
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
||||||
pool->set_thread_count(nthreads);
|
|
||||||
#endif // C10_MOBILE
|
#endif // C10_MOBILE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_PTHREADPOOL
|
#ifdef USE_PTHREADPOOL
|
||||||
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
|
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
|
||||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
|
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
|
||||||
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
||||||
pool->set_thread_count(nthreads);
|
|
||||||
#endif
|
#endif
|
||||||
#if AT_MKLDNN_ENABLED()
|
#if AT_MKLDNN_ENABLED()
|
||||||
at::native::mkldnn::clear_computation_cache();
|
at::native::mkldnn::clear_computation_cache();
|
||||||
|
|
|
||||||
|
|
@ -82,12 +82,9 @@ void PThreadPool::run(
|
||||||
0u);
|
0u);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward declaration
|
PThreadPool* pthreadpool(size_t thread_count) {
|
||||||
size_t getDefaultNumThreads();
|
|
||||||
|
|
||||||
PThreadPool* pthreadpool() {
|
|
||||||
static auto threadpool =
|
static auto threadpool =
|
||||||
std::make_unique<PThreadPool>(getDefaultNumThreads());
|
std::make_unique<PThreadPool>(thread_count);
|
||||||
#if !(defined(WIN32))
|
#if !(defined(WIN32))
|
||||||
static std::once_flag flag;
|
static std::once_flag flag;
|
||||||
std::call_once(flag, []() {
|
std::call_once(flag, []() {
|
||||||
|
|
@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
|
||||||
return threadpool.get();
|
return threadpool.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward declaration
|
||||||
|
size_t getDefaultNumThreads();
|
||||||
|
|
||||||
|
PThreadPool* pthreadpool() {
|
||||||
|
return pthreadpool(getDefaultNumThreads());
|
||||||
|
}
|
||||||
|
|
||||||
pthreadpool_t pthreadpool_() {
|
pthreadpool_t pthreadpool_() {
|
||||||
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
|
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ class PThreadPool final {
|
||||||
|
|
||||||
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
|
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
|
||||||
PThreadPool* pthreadpool();
|
PThreadPool* pthreadpool();
|
||||||
|
PThreadPool* pthreadpool(size_t thread_count);
|
||||||
|
|
||||||
// Exposes the underlying implementation of PThreadPool.
|
// Exposes the underlying implementation of PThreadPool.
|
||||||
// Only for use in external libraries so as to unify threading across
|
// Only for use in external libraries so as to unify threading across
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user