Native ATen/Parallel backend (#20087)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20087
ghimport-source-id: bcfc8a86abe0893e4a380fe6f6123e2082ba4317
Differential Revision: D15248663
Pulled By: ilia-cher
fbshipit-source-id: fdb7a8860c85d8202026b629cb7fa344782bd2c4
This commit is contained in:
parent
f4b434a6a5
commit
82aecfad6a
aten/src/ATen/Parallel.h
@@ -31,6 +31,18 @@ CAFFE2_API int get_thread_num();
 
 // Checks whether the code runs in parallel region
 CAFFE2_API bool in_parallel_region();
 
+/*
+parallel_for
+
+begin: index at which to start applying user function
+
+end: index at which to stop applying user function
+
+grain_size: number of elements per chunk. impacts the degree of parallelization
+
+f: user function applied in parallel to the chunks, signature:
+  void f(int64_t begin, int64_t end)
+*/
 template <class F>
 inline void parallel_for(
     const int64_t begin,
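For orientation, a minimal usage sketch of the primitive declared above (the vector, scale factor, and grain size of 2048 are illustrative, not part of the patch):

#include <ATen/Parallel.h>
#include <vector>

void scale(std::vector<float>& v, float a) {
  // Ranges of at least grain_size elements are split across the intra-op
  // pool; smaller ranges run inline on the calling thread.
  at::parallel_for(0, (int64_t)v.size(), /* grain_size */ 2048,
      [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; ++i) {
          v[i] *= a;
        }
      });
}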
@@ -90,10 +102,15 @@ CAFFE2_API void set_num_interop_threads(int);
 CAFFE2_API int get_num_interop_threads();
 
 // Launches inter-op parallel task
-CAFFE2_API void launch(const std::function<void()>& func);
+CAFFE2_API void launch(std::function<void()> func);
+
+// Launches intra-op parallel task
+CAFFE2_API void intraop_launch(std::function<void()> func);
 
 } // namespace at
 
+#if AT_PARALLEL_OPENMP
+#include <ATen/ParallelOpenMP.h>
+#elif AT_PARALLEL_NATIVE
+#include <ATen/ParallelNative.h>
+#endif
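A hedged sketch of how the two launchers differ in intent (the lambdas are placeholders):

#include <ATen/Parallel.h>

void enqueue_work() {
  // Inter-op: queue an independent task on the shared inter-op pool.
  at::launch([]() { /* independent work */ });

  // Intra-op: queue on the intra-op pool, or, per the native backend below,
  // run inline when the caller is already inside a parallel region.
  at::intraop_launch([]() { /* cooperating work */ });
}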
aten/src/ATen/ParallelNative.cpp (new file, 131 lines)
@@ -0,0 +1,131 @@
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/PTThreadPool.h>

#include <atomic>

#ifdef _OPENMP
#include <omp.h>
#endif

#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif

namespace at {
namespace {
const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of threads set by the user
// NOT_SET -> positive value -> CONSUMED
// or
// NOT_SET -> CONSUMED
// Meaning:
//  - NOT_SET - pool not initialized, user value is not set
//  - positive value - pool not initialized, user value set
//  - CONSUMED - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};

// used with _set_in_parallel_region to mark master thread
// as in parallel region while executing parallel primitives
thread_local bool in_parallel_region_ = false;

// thread number (task_id) set by parallel primitive
thread_local size_t thread_num_ = 0;

int _num_pool_threads(int nthreads) {
  if (nthreads == NOT_SET) {
    nthreads = TaskThreadPoolBase::defaultNumThreads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
  }
  // minus one because of the master thread
  return nthreads - 1;
}
} // namespace

namespace internal {

TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
          /* create_new */ true); // create a separate thread pool for intra-op
  return *pool;
}

void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

void _set_thread_num(size_t thread_num) {
  thread_num_ = thread_num;
}

void _unset_thread_num() {
  thread_num_ = 0;
}

} // namespace internal

// TODO: use OMP and MKL env. vars as default values
void init_num_threads() {
#ifdef _OPENMP
  omp_set_num_threads(1);
#endif

#ifdef TH_BLAS_MKL
  mkl_set_num_threads(1);
#endif
}

void set_num_threads(int nthreads) {
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int no_value = NOT_SET;
  TORCH_CHECK(num_intraop_threads.compare_exchange_strong(no_value, nthreads),
      "Error: cannot set number of intra-op threads "
      "after parallel work has started or after set_num_threads call");
}

int get_num_threads() {
  // not initializing pool unnecessarily,
  // because pool cannot be resized after initialization
  int nthreads = num_intraop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    return TaskThreadPoolBase::defaultNumThreads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
    return internal::_get_intraop_pool().size() + 1;
  }
}

int get_thread_num() {
  return thread_num_;
}

bool in_parallel_region() {
  return in_parallel_region_ || (
      num_intraop_threads.load() == CONSUMED &&
      internal::_get_intraop_pool().inThreadPool()
  );
}

void intraop_launch(std::function<void()> func) {
  if (!in_parallel_region()) {
    internal::_get_intraop_pool().run([func]() {
      func();
    });
  } else {
    // execute inline if we're in parallel region
    func();
  }
}

} // namespace at
#endif
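A sketch of how the NOT_SET -> positive value -> CONSUMED protocol above plays out from caller code (the thread count 4 and the loop bounds are illustrative):

#include <ATen/Parallel.h>

int main() {
  at::init_num_threads();  // caps OMP/MKL at one thread; pool not built yet
  at::set_num_threads(4);  // NOT_SET -> 4; must precede any parallel work
  at::parallel_for(0, 1 << 20, 1, [](int64_t, int64_t) { /* work */ });
  // The first parallel call built the pool (state is now CONSUMED), so a
  // second at::set_num_threads(...) would fail the TORCH_CHECK above.
  return at::get_num_threads() == 4 ? 0 : 1;  // pool size + master == 4
}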
aten/src/ATen/ParallelNative.h (new file, 173 lines)
@@ -0,0 +1,173 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

#include <c10/core/thread_pool.h>

#include <algorithm>
#include <cstddef>
#include <exception>

#define INTRA_OP_PARALLEL

namespace at {
namespace internal {
// internal function to get access to intra-op thread pool from
// template parallel primitives (parallel_for, parallel_reduce)
CAFFE2_API TaskThreadPoolBase& _get_intraop_pool();

// internal utility function to mark master thread as in parallel
// region when executing parallel primitives
CAFFE2_API void _set_in_parallel_region(bool);

// Simulate OMP's omp_get_thread_num() by force-setting thread local
// task id as thread number when executing parallel primitives
CAFFE2_API void _set_thread_num(size_t thread_num);
CAFFE2_API void _unset_thread_num();
} // namespace internal

template <class F>
inline void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
  TORCH_CHECK(grain_size >= 0);
  if (begin >= end) {
    return;
  }

  if (((end - begin) >= grain_size) && !in_parallel_region()) {
    // choose number of tasks based on grain size and number of threads
    size_t chunk_size = divup((end - begin), get_num_threads());
    // make sure each task is at least grain_size size
    chunk_size = std::max((size_t)grain_size, chunk_size);
    size_t num_tasks = divup((end - begin), chunk_size);

    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    auto task = [f, &eptr, &err_flag]
        (int64_t task_id, int64_t local_start, int64_t local_end) {
      internal::_set_thread_num(task_id);
      internal::_set_in_parallel_region(true);
      try {
        f(local_start, local_end);
      } catch (...) {
        if (!err_flag.test_and_set()) {
          eptr = std::current_exception();
        }
      }
      internal::_set_in_parallel_region(false);
      internal::_unset_thread_num();
    };

    // using a shared_ptr to share ownership of the future with the lambda,
    // to ensure we don't destroy the future while the lambda is still
    // running in markCompleted
    std::vector<std::shared_ptr<ivalue::Future>> futures(num_tasks);
    for (size_t task_id = 1; task_id < num_tasks; ++task_id) {
      futures[task_id] = std::make_shared<ivalue::Future>();
      auto future_ptr = futures[task_id];
      int64_t local_start = begin + task_id * chunk_size;
      if (local_start < end) {
        int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
        internal::_get_intraop_pool().run(
            // copy future_ptr, task_id, local_start, local_end
            [task, future_ptr, task_id, local_start, local_end]() {
              task(task_id, local_start, local_end);
              future_ptr->markCompleted(IValue());
            });
      } else {
        future_ptr->markCompleted(IValue());
      }
    }

    int64_t first_task_end = std::min(end, (int64_t)(chunk_size + begin));
    task(0, begin, first_task_end);
    // wait for all tasks to finish
    for (size_t task_id = 1; task_id < num_tasks; ++task_id) {
      futures[task_id]->wait();
    }
    if (eptr) {
      std::rethrow_exception(eptr);
    }
  } else {
    f(begin, end);
  }
}

template <class scalar_t, class F, class SF>
inline scalar_t parallel_reduce(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const scalar_t ident,
    const F& f,
    const SF& sf) {
  TORCH_CHECK(grain_size >= 0);
  if (begin >= end) {
    return ident;
  }

  if (((end - begin) >= grain_size) && !in_parallel_region()) {
    size_t chunk_size = divup((end - begin), get_num_threads());
    chunk_size = std::max((size_t)grain_size, chunk_size);
    size_t num_tasks = divup((end - begin), chunk_size);
    std::vector<scalar_t> results(num_tasks);
    scalar_t* results_data = results.data();

    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    auto task = [f, ident, results_data, &eptr, &err_flag]
        (int64_t task_id, int64_t local_start, int64_t local_end) {
      internal::_set_thread_num(task_id);
      internal::_set_in_parallel_region(true);
      try {
        results_data[task_id] = f(local_start, local_end, ident);
      } catch (...) {
        if (!err_flag.test_and_set()) {
          eptr = std::current_exception();
        }
      }
      internal::_set_in_parallel_region(false);
      internal::_unset_thread_num();
    };

    std::vector<std::shared_ptr<ivalue::Future>> futures(num_tasks);
    for (size_t task_id = 1; task_id < num_tasks; ++task_id) {
      futures[task_id] = std::make_shared<ivalue::Future>();
      auto future_ptr = futures[task_id];
      int64_t local_start = begin + task_id * chunk_size;
      if (local_start < end) {
        int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
        internal::_get_intraop_pool().run(
            // copy future_ptr, task_id, local_start, local_end
            [&, future_ptr, task_id, local_start, local_end]() {
              task(task_id, local_start, local_end);
              future_ptr->markCompleted(IValue());
            });
      } else {
        future_ptr->markCompleted(IValue());
      }
    }

    int64_t first_task_end = std::min(end, (int64_t)(chunk_size + begin));
    task(0, begin, first_task_end);
    for (size_t task_id = 1; task_id < num_tasks; ++task_id) {
      futures[task_id]->wait();
    }
    if (eptr) {
      std::rethrow_exception(eptr);
    }

    scalar_t result = ident;
    for (auto partial_result : results) {
      result = sf(result, partial_result);
    }
    return result;
  } else {
    return f(begin, end, ident);
  }
}

} // namespace at
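And a usage sketch for the reduction primitive, summing a vector (the function name and grain size are illustrative):

#include <ATen/Parallel.h>

#include <functional>
#include <vector>

double parallel_sum(const std::vector<double>& v) {
  // f folds each chunk starting from ident; the partial results are then
  // combined sequentially with sf on the calling thread.
  return at::parallel_reduce(
      0, (int64_t)v.size(), /* grain_size */ 2048, /* ident */ 0.0,
      [&](int64_t begin, int64_t end, double acc) {
        for (int64_t i = begin; i < end; ++i) {
          acc += v[i];
        }
        return acc;
      },
      std::plus<double>());
}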
aten/src/ATen/ParallelOpenMP.cpp
@@ -31,7 +31,7 @@ void init_num_threads() {
 }
 
 void set_num_threads(int nthreads) {
-  AT_CHECK(nthreads > 0, "Expected positive number of threads");
+  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
   num_threads.store(nthreads);
 #ifdef _OPENMP
   omp_set_num_threads(nthreads);
@@ -76,5 +76,10 @@ bool in_parallel_region() {
 #endif
 }
 
+void intraop_launch(std::function<void()> func) {
+  // execute inline in openmp case
+  func();
+}
+
 } // namespace at
 #endif
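Note the contrast with the native backend: under OpenMP there is no intra-op pool to queue onto, so intraop_launch simply runs the task inline and callers cannot assume any asynchrony.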
aten/src/ATen/ParallelOpenMP.h
@@ -18,6 +18,7 @@ inline void parallel_for(
     const int64_t end,
     const int64_t grain_size,
     const F& f) {
+  TORCH_CHECK(grain_size >= 0);
   if (begin >= end) {
     return;
   }
@@ -56,6 +57,7 @@ inline scalar_t parallel_reduce(
     const scalar_t ident,
     const F& f,
     const SF& sf) {
+  TORCH_CHECK(grain_size >= 0);
   if (begin >= end) {
     return ident;
   } else if (in_parallel_region() || get_num_threads() == 1) {
aten/src/ATen/ParallelThreadPoolNative.cpp
@@ -1,4 +1,4 @@
-#if AT_PARALLEL_OPENMP
+#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE
 #include <ATen/Parallel.h>
 #include <ATen/PTThreadPool.h>
 
@@ -35,9 +35,9 @@ std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
     int pool_size,
     bool create_new) {
   // For now, the only accepted device id is 0
-  AT_CHECK(device_id == 0);
+  TORCH_CHECK(device_id == 0);
   // Create new thread pool
-  AT_CHECK(create_new);
+  TORCH_CHECK(create_new);
   return std::make_shared<PTThreadPool>(pool_size);
 }
 
@@ -46,10 +46,10 @@ std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
 C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
 
 void set_num_interop_threads(int nthreads) {
-  AT_CHECK(nthreads > 0, "Expected positive number of threads");
+  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
 
   int no_value = NOT_SET;
-  AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
+  TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
       "Error: cannot set number of interop threads after parallel work "
      "has started or set_num_interop_threads called");
 }
@@ -66,7 +66,7 @@ int get_num_interop_threads() {
   }
 }
 
-void launch(const std::function<void()>& func) {
+void launch(std::function<void()> func) {
   get_pool().run(func);
 }
 
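Taking std::function by value here (and in the Parallel.h declaration above) presumably lets the pool copy or move the callable into the queued task rather than binding to a reference whose lifetime it does not control.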
aten/src/ATen/test/thread_init_test.cpp
@@ -23,23 +23,12 @@ int main() {
   at::init_num_threads();
   at::manual_seed(123);
 
-  test(at::get_num_threads());
-  std::thread t1(test, at::get_num_threads());
+  at::set_num_threads(4);
+  test(4);
+  std::thread t1(test, 4);
   t1.join();
 
-  at::set_num_threads(4);
-  std::thread t2(test, at::get_num_threads());
-  std::thread t3(test, at::get_num_threads());
-  std::thread t4(test, at::get_num_threads());
-  t4.join();
-  t3.join();
-  t2.join();
-
-  at::set_num_threads(5);
-  test(at::get_num_threads());
-
   // test inter-op settings
   ASSERT_EQ(at::get_num_interop_threads(), std::thread::hardware_concurrency());
   at::set_num_interop_threads(5);
   ASSERT_EQ(at::get_num_interop_threads(), 5);
   ASSERT_ANY_THROW(at::set_num_interop_threads(6));
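The reworked test pins the expected count to the value passed to set_num_threads instead of reading it back first, matching the native backend's rule that the thread count is fixed once parallel work has started.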
c10/core/thread_pool.h
@@ -38,7 +38,11 @@ class C10_API TaskThreadPoolBase {
   virtual ~TaskThreadPoolBase() noexcept {}
 
   static size_t defaultNumThreads() {
-    return std::thread::hardware_concurrency();
+    auto num_threads = std::thread::hardware_concurrency();
+#if defined(_M_X64) || defined(__x86_64__)
+    num_threads /= 2;
+#endif
+    return num_threads;
   }
 };
 
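The halving on x86_64 presumably discounts SMT/hyperthreads, so the default pool targets roughly one worker per physical core.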
caffe2/CMakeLists.txt
@@ -865,9 +865,12 @@ target_compile_options(caffe2 PRIVATE "-DCAFFE2_BUILD_MAIN_LIB")
 
 # Parallelism settings
+# OPENMP - OpenMP for intra-op, native thread pool for inter-op parallelism
+# NATIVE - using native thread pool for intra- and inter-op parallelism
 set(PARALLEL_BACKEND "OPENMP" CACHE STRING "ATen parallel backend")
 if ("${PARALLEL_BACKEND}" STREQUAL "OPENMP")
   target_compile_definitions(caffe2 PUBLIC "-DAT_PARALLEL_OPENMP=1")
+elseif ("${PARALLEL_BACKEND}" STREQUAL "NATIVE")
+  target_compile_definitions(caffe2 PUBLIC "-DAT_PARALLEL_NATIVE=1")
 else()
   message(FATAL_ERROR "Unknown parallel backend: ${PARALLEL_BACKEND}")
 endif()
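With this cache variable the native pool can be selected at configure time, e.g. cmake -DPARALLEL_BACKEND=NATIVE (illustrative invocation); OPENMP remains the default.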
setup.py
@@ -154,6 +154,7 @@
 # parallel backend to use for intra- and inter-op parallelism
 # possible values:
 #   OPENMP - use OpenMP for intra-op and native backend for inter-op tasks
+#   NATIVE - use native thread pool for both intra- and inter-op tasks
 
 from __future__ import print_function
 from setuptools import setup, Extension, distutils, find_packages