Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 00:19:58 +01:00
[xla:cpu:ynn] Implement SlinkyThreadPool on top of WorkQueue and Worker APIs

Remove `work_queue` and `worker` that were originally forked from `xla::cpu::WorkQueue` and `Worker`.

PiperOrigin-RevId: 823179793

This commit is contained in:
parent 96c1b6c0a6
commit 7ba3317857
@@ -55,6 +55,8 @@ class WorkQueue {

  size_t num_partitions() const { return partitions_.size(); }

  bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }

 private:
  friend class Worker;

@@ -67,9 +69,8 @@
    size_t end;
  };

  // An empty work queue flag to stop worker threads from looping through all
  // partitions looking for work.
  bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }

  // Sets an empty work queue flag to stop worker threads from looping through
  // all partitions looking for work.
  void SetEmpty() { empty_.store(true, std::memory_order_relaxed); }

  // Notify that one of the workers switched to the work stealing mode.
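The IsEmpty()/SetEmpty() pair is a termination flag for work stealing: a worker scans partitions starting from its own, and the first worker to wrap all the way around marks the queue empty so later scans stop immediately. Below is a toy, single-threaded model of that protocol (illustrative only; ToyQueue is a hypothetical stand-in for xla::cpu::WorkQueue, whose forked implementation appears in the deleted work_queue/worker code further down):

#include <array>
#include <atomic>
#include <cstddef>
#include <iostream>
#include <optional>

constexpr size_t kPartitions = 4;
constexpr size_t kItemsPerPartition = 2;

struct ToyQueue {  // Hypothetical stand-in for xla::cpu::WorkQueue.
  std::array<std::atomic<size_t>, kPartitions> next{};  // next index per partition
  std::atomic<bool> empty{false};

  bool IsEmpty() const { return empty.load(std::memory_order_relaxed); }
  void SetEmpty() { empty.store(true, std::memory_order_relaxed); }

  std::optional<size_t> Pop(size_t p) {
    size_t i = next[p].fetch_add(1, std::memory_order_relaxed);
    if (i >= kItemsPerPartition) return std::nullopt;
    return p * kItemsPerPartition + i;  // global work item index
  }
};

int main() {
  ToyQueue q;
  size_t start = 1, p = start;  // worker assigned to partition 1
  while (!q.IsEmpty()) {
    if (auto item = q.Pop(p)) {
      std::cout << "process item " << *item << "\n";
      continue;
    }
    p = (p + 1) % kPartitions;     // assigned partition drained: steal from the next one
    if (p == start) q.SetEmpty();  // wrapped around: all partitions drained
  }
}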
@@ -14,6 +14,35 @@ package_group(
    ],
)

cc_library(
    name = "slinky_threadpool",
    srcs = ["slinky_threadpool.cc"],
    hdrs = ["slinky_threadpool.h"],
    deps = [
        "//xla/backends/cpu/runtime:work_queue",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:fixed_array",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/synchronization",
        "@eigen_archive//:eigen3",
        "@slinky//slinky/base:thread_pool",
    ],
)

ynn_cc_test(
    name = "slinky_threadpool_test",
    srcs = ["slinky_threadpool_test.cc"],
    deps = [
        ":slinky_threadpool",
        "//xla/tsl/platform:env",
        "@com_google_googletest//:gtest_main",
        "@slinky//slinky/base:thread_pool",
    ],
)

cc_library(
    name = "ynn_interop",
    srcs = ["ynn_interop.cc"],
@@ -35,6 +64,7 @@ cc_library(
    srcs = ["ynn_threadpool.cc"],
    hdrs = ["ynn_threadpool.h"],
    deps = [
        ":slinky_threadpool",
        ":ynn_interop",
        "@XNNPACK//ynnpack:ynnpack_h",
        "@com_google_absl//absl/algorithm:container",
@@ -49,17 +79,6 @@ cc_library(
    ],
)

ynn_cc_test(
    name = "ynn_threadpool_test",
    srcs = ["ynn_threadpool_test.cc"],
    deps = [
        ":ynn_threadpool",
        "//xla/tsl/platform:env",
        "@com_google_googletest//:gtest_main",
        "@slinky//slinky/base:thread_pool",
    ],
)

cc_library(
    name = "ynn_fusion_thunk",
    srcs = ["ynn_fusion_thunk.cc"],
third_party/xla/xla/backends/cpu/runtime/ynnpack/slinky_threadpool.cc (vendored, new file, 416 lines)

@@ -0,0 +1,416 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "slinky/base/function_ref.h"
|
||||
#include "slinky/base/ref_count.h"
|
||||
#include "slinky/base/thread_pool.h"
|
||||
#include "xla/backends/cpu/runtime/work_queue.h"
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#include "Eigen/ThreadPool"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace xla::cpu {

//===----------------------------------------------------------------------===//
// Task
//===----------------------------------------------------------------------===//

namespace {

// Running a task can result in three states:
//
// kPending: The task is still being processed by the worker threads.
// kComplete: The caller thread is the one who completed the task.
// kDone: The task is done and all work items have been processed, however
//   the caller thread didn't process any work items.
//
// We need this state to signal the waiter thread just once, from the thread
// that completed the task.
enum class TaskState { kPending, kComplete, kDone };
class Task final : public SlinkyThreadPool::task {
 public:
  Task(SlinkyThreadPool::task_body body, size_t num_work_items,
       size_t num_partitions);

  // Runs this task by processing work items in the current thread.
  TaskState Run();

  // Returns the number of workers that are currently working on this task.
  int64_t num_workers() const;

  bool is_empty_work_queue() const;
  bool done() const final;

 private:
  SlinkyThreadPool::task_body body_;
  WorkQueue work_queue_;

  ABSL_CACHELINE_ALIGNED std::atomic<size_t> worker_index_;
  ABSL_CACHELINE_ALIGNED std::atomic<size_t> pending_work_items_;
};

}  // namespace
Task::Task(SlinkyThreadPool::task_body body, size_t num_work_items,
           size_t num_partitions)
    : body_(std::move(body)),
      work_queue_(num_work_items, num_partitions),
      worker_index_(0),
      pending_work_items_(num_work_items) {}

TaskState Task::Run() {
  // If we have more workers joining the task than the number of partitions,
  // then we have to wrap around to the first partition.
  size_t worker_index = worker_index_.fetch_add(1, std::memory_order_relaxed);
  if (ABSL_PREDICT_FALSE(worker_index >= work_queue_.num_partitions())) {
    worker_index %= work_queue_.num_partitions();
  }

  // Each worker processes the body using its own copy of the task body.
  Worker w(worker_index, &work_queue_);
  size_t num_processed_work_items = 0;

  if (std::optional<size_t> item = w.Pop()) {
    SlinkyThreadPool::task_body body = body_;

    do {
      body(*item);
      ++num_processed_work_items;
    } while ((item = w.Pop()).has_value());
  }

  // The number of pending work items should never go below zero.
  size_t previous_work_items = pending_work_items_.fetch_sub(
      num_processed_work_items, std::memory_order_acq_rel);
  DCHECK_GE(previous_work_items, num_processed_work_items);

  // Task is done if we have no more work items to process. Task is complete if
  // we are the one who processed the last work item.
  bool is_done = previous_work_items == num_processed_work_items;
  bool is_complete = is_done && num_processed_work_items > 0;

  return is_complete ? TaskState::kComplete
         : is_done   ? TaskState::kDone
                     : TaskState::kPending;
}

int64_t Task::num_workers() const {
  return worker_index_.load(std::memory_order_relaxed);
}

bool Task::is_empty_work_queue() const { return work_queue_.IsEmpty(); }

bool Task::done() const {
  return pending_work_items_.load(std::memory_order_acquire) == 0;
}
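The three-way return value of Run() is decided by a single fetch_sub on the pending-item counter. A minimal sketch of that decision, replayed with concrete numbers (the Classify helper is hypothetical, but its body is the same arithmetic as Task::Run above):

#include <atomic>
#include <cassert>
#include <cstddef>

enum class TaskState { kPending, kComplete, kDone };

TaskState Classify(std::atomic<size_t>& pending, size_t processed) {
  size_t previous = pending.fetch_sub(processed, std::memory_order_acq_rel);
  bool is_done = previous == processed;  // we removed the last pending items
  bool is_complete = is_done && processed > 0;
  return is_complete ? TaskState::kComplete
         : is_done   ? TaskState::kDone
                     : TaskState::kPending;
}

int main() {
  std::atomic<size_t> pending(8);  // a task with 8 work items
  assert(Classify(pending, 5) == TaskState::kPending);   // 3 items remain
  assert(Classify(pending, 3) == TaskState::kComplete);  // finished the task
  assert(Classify(pending, 0) == TaskState::kDone);      // saw it already done
}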
//===----------------------------------------------------------------------===//
// SlinkyThreadPool::Impl
//===----------------------------------------------------------------------===//

// We keep a stack of tasks that are currently being processed by the current
// thread, to avoid recursive calls.
static thread_local std::vector<const Task*> task_stack;  // NOLINT

class SlinkyThreadPool::Impl : public slinky::ref_counted<Impl> {
 public:
  explicit Impl(Eigen::ThreadPoolInterface* threadpool);

  // Enqueues a new task into the queue and returns a reference to it.
  slinky::ref_count<Task> Enqueue(SlinkyThreadPool::task_body body,
                                  size_t num_work_items, size_t num_partitions);

  // Work on the single task and return the state of the task.
  TaskState WorkOnTask(Task* task);

  // Work on all tasks in the queue. Returns when we run out of tasks to
  // process.
  void WorkOnTasks(const absl::Condition& condition);

  void Await(const absl::Condition& condition);
  void AtomicCall(slinky::function_ref<void()> t);

  // Returns true if we can schedule more workers into the underlying scheduler.
  bool CanScheduleWorkers() const;

  // Schedules the given number of workers for the given task. Worker scheduling
  // uses recursive work splitting and exits early if the task does not need any
  // more workers, or if we reached the maximum number of scheduled workers.
  void ScheduleWorkers(int64_t num_workers, slinky::ref_count<Task> task);

  size_t thread_count() const { return thread_count_; }

 private:
  friend class slinky::ref_counted<Impl>;
  static void destroy(Impl* ptr) { delete ptr; }

  // A state of the work scheduling for a given task.
  struct ScheduleState : public slinky::ref_counted<ScheduleState> {
    ScheduleState(int64_t remaining_workers, slinky::ref_count<Task> task,
                  slinky::ref_count<Impl> impl)
        : remaining_workers(remaining_workers),
          task(std::move(task)),
          impl(std::move(impl)) {}

    static void destroy(ScheduleState* ptr) { delete ptr; }

    std::atomic<int64_t> remaining_workers;
    slinky::ref_count<Task> task;
    slinky::ref_count<Impl> impl;
  };

  // Worker scheduling function for the underlying scheduler.
  template <bool release_impl_ref>
  static void ScheduleWorkers(ScheduleState* context);

  // Dequeues a pending task from the queue.
  slinky::ref_count<Task> Dequeue();

  // Signals all waiter threads waiting on the waiter mutex.
  void SignalWaiters();

  Eigen::ThreadPoolInterface* threadpool_;
  size_t thread_count_;

  std::deque<slinky::ref_count<Task>> tasks_ ABSL_GUARDED_BY(tasks_mutex_);

  // A mutex for guarding mutable state accessed concurrently.
  ABSL_CACHELINE_ALIGNED absl::Mutex tasks_mutex_;

  // A mutex for signalling threads waiting on the tasks or conditions.
  ABSL_CACHELINE_ALIGNED absl::Mutex waiter_mutex_;
};
SlinkyThreadPool::Impl::Impl(Eigen::ThreadPoolInterface* threadpool)
    : threadpool_(threadpool),
      thread_count_(threadpool_ ? threadpool_->NumThreads() : 0) {}

slinky::ref_count<Task> SlinkyThreadPool::Impl::Enqueue(
    SlinkyThreadPool::task_body body, size_t num_work_items,
    size_t num_partitions) {
  slinky::ref_count<Task> task(
      new Task(std::move(body), num_work_items, num_partitions));

  absl::MutexLock lock(tasks_mutex_);
  return tasks_.emplace_back(std::move(task));
}

slinky::ref_count<Task> SlinkyThreadPool::Impl::Dequeue() {
  absl::MutexLock lock(tasks_mutex_);

  for (auto i = tasks_.begin(); i != tasks_.end();) {
    slinky::ref_count<Task>& task = *i;

    // Task doesn't have any more work items to process.
    if (ABSL_PREDICT_FALSE(task->is_empty_work_queue())) {
      i = tasks_.erase(i);
      continue;
    }

    // Don't run the same task multiple times on the same thread.
    if (ABSL_PREDICT_FALSE(absl::c_contains(task_stack, &*task))) {
      ++i;
      continue;
    }

    return task;
  }

  return nullptr;
}

TaskState SlinkyThreadPool::Impl::WorkOnTask(Task* task) {
  DCHECK(absl::c_find(task_stack, task) == task_stack.end());

  task_stack.push_back(task);
  TaskState state = task->Run();
  task_stack.pop_back();

  // If we are the one who completed the task, we signal the waiters to wake up
  // any threads that are waiting for the task completion. If the task was
  // completed by another worker, we do nothing to avoid the cost of waking up
  // the same thread multiple times.
  if (ABSL_PREDICT_FALSE(state == TaskState::kComplete)) {
    SignalWaiters();
  }

  return state;
}
void SlinkyThreadPool::Impl::WorkOnTasks(const absl::Condition& condition) {
  while (slinky::ref_count<Task> task = Dequeue()) {
    WorkOnTask(&*task);

    if (ABSL_PREDICT_TRUE(condition.Eval())) {
      return;
    }
  }
}

void SlinkyThreadPool::Impl::Await(const absl::Condition& condition) {
  if (ABSL_PREDICT_FALSE(!condition.Eval())) {
    absl::MutexLock lock(waiter_mutex_);
    waiter_mutex_.Await(condition);
  }
}

void SlinkyThreadPool::Impl::SignalWaiters() {
  // Acquiring and releasing the waiter mutex forces threads blocked in Await()
  // to re-evaluate their conditions.
  absl::MutexLock lock(waiter_mutex_);
}

void SlinkyThreadPool::Impl::AtomicCall(slinky::function_ref<void()> t) {
  absl::MutexLock lock(waiter_mutex_);
  t();
}

bool SlinkyThreadPool::Impl::CanScheduleWorkers() const {
  // One reference is owned by the parent SlinkyThreadPool, every other
  // reference is owned by a worker scheduled into the underlying scheduler.
  return ref_count() < 1 + thread_count();
}
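CanScheduleWorkers() folds worker accounting into the Impl reference count itself. A worked check of that invariant with plain integers (illustrative only; in the real code the count is maintained by add_ref/release on Impl, not by a local variable):

#include <cassert>
#include <cstddef>

int main() {
  size_t thread_count = 4;
  size_t ref_count = 1;  // the parent SlinkyThreadPool's reference

  auto can_schedule = [&] { return ref_count < 1 + thread_count; };

  assert(can_schedule());   // 0 workers in flight
  ref_count += 4;           // 4 workers scheduled, one reference each
  assert(!can_schedule());  // saturated: in-flight workers == thread_count
  ref_count -= 1;           // one worker retires and drops its reference
  assert(can_schedule());   // room for one more worker
}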
void SlinkyThreadPool::Impl::ScheduleWorkers(int64_t num_workers,
                                             slinky::ref_count<Task> task) {
  if (ABSL_PREDICT_TRUE(num_workers > 0 && CanScheduleWorkers())) {
    slinky::ref_count<ScheduleState> state(
        new ScheduleState(num_workers - 1, std::move(task), {this}));
    threadpool_->Schedule([state = state.take()]() {
      ScheduleWorkers</*release_impl_ref=*/false>(state);
    });
  }
}

template <bool release_impl_ref>
void SlinkyThreadPool::Impl::ScheduleWorkers(ScheduleState* context) {
  auto state = slinky::ref_count<ScheduleState>::assume(context);

  // We recursively keep scheduling workers into the underlying scheduler.
  // This is more efficient than scheduling them sequentially from a single
  // thread, because workers can start processing the task sooner and we
  // distribute thread wake-ups evenly across underlying threads.
  static constexpr int32_t kNumRecursiveWorkers = 2;

  for (size_t i = 0; i < kNumRecursiveWorkers; ++i) {
    bool schedule_worker =
        state->impl->CanScheduleWorkers() &&
        !state->task->is_empty_work_queue() &&
        state->remaining_workers.fetch_sub(1, std::memory_order_relaxed) > 0;

    if (ABSL_PREDICT_TRUE(!schedule_worker)) {
      break;
    }

    // Add +1 reference to account for the scheduled worker, as we use the
    // `impl` reference count to track the number of active workers.
    state->impl->add_ref();
    state->impl->threadpool_->Schedule(
        [state = slinky::ref_count<ScheduleState>(state).take()]() {
          SlinkyThreadPool::Impl::ScheduleWorkers</*release_impl_ref=*/true>(
              state);
        });
  }

  // Keep processing tasks from the queue until we are out of tasks.
  static constexpr bool kFalse = false;
  state->impl->WorkOnTasks(absl::Condition(&kFalse));

  // One `impl` reference is implicitly owned by the `state`, every additional
  // reference is added and released explicitly by the worker task.
  if constexpr (release_impl_ref) {
    state->impl->release();
  }
}
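The recursive fan-out means each scheduled worker wakes up to two more before it starts working, so N workers start in O(log N) waves rather than via N sequential Schedule() calls from one thread. A standalone simulation of that shape (illustrative only, not the real scheduler):

#include <cstdio>
#include <queue>

int main() {
  int remaining = 7;     // workers still to schedule (num_workers = 7)
  std::queue<int> wave;  // in-flight workers, tagged with their wave number
  wave.push(0);          // the first worker scheduled by ScheduleWorkers()
  --remaining;
  while (!wave.empty()) {
    int depth = wave.front();
    wave.pop();
    // Like kNumRecursiveWorkers = 2: each worker schedules up to two more.
    for (int i = 0; i < 2 && remaining > 0; ++i, --remaining) {
      std::printf("worker at wave %d schedules a worker for wave %d\n", depth,
                  depth + 1);
      wave.push(depth + 1);
    }
  }
  // All 7 workers start within 3 waves (1 + 2 + 4) instead of 7 sequential
  // wake-ups issued by a single thread.
}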
//===----------------------------------------------------------------------===//
// SlinkyThreadPool
//===----------------------------------------------------------------------===//

SlinkyThreadPool::SlinkyThreadPool(Eigen::ThreadPoolDevice* device)
    : impl_(new Impl(device ? device->getPool() : nullptr)) {}

SlinkyThreadPool::SlinkyThreadPool(Eigen::ThreadPoolInterface* threadpool)
    : impl_(new Impl(threadpool)) {}

SlinkyThreadPool::~SlinkyThreadPool() = default;

slinky::ref_count<SlinkyThreadPool::task> SlinkyThreadPool::enqueue(
    size_t n, task_body t, int32_t max_workers) {
  CHECK_GE(max_workers, n);

  // Don't create more partitions than the number of threads. Also make sure
  // that we have at least one partition (if we don't have a scheduler).
  size_t num_partitions = std::min<size_t>(n, thread_count());
  num_partitions = std::max<size_t>(1, num_partitions);

  auto task = impl_->Enqueue(std::move(t), n, num_partitions);

  // If we don't have any worker threads, we return the task to the caller and
  // assume that the caller will wait on it.
  if (ABSL_PREDICT_FALSE(impl_->thread_count() == 0)) {
    return task;
  }

  // We assume that the caller will immediately start working on the task, so
  // we need to schedule workers only for the remaining number of partitions.
  impl_->ScheduleWorkers(/*num_workers=*/num_partitions - 1, task);

  return task;
}

void SlinkyThreadPool::wait_for(task* t) {
  Task* task = static_cast<Task*>(t);
  TaskState state = impl_->WorkOnTask(task);

  // If the task is complete or done, we are immediately done with waiting.
  if (ABSL_PREDICT_TRUE(state == TaskState::kComplete ||
                        state == TaskState::kDone)) {
    return;
  }

  // Switch to the work stealing mode and work on other tasks in the queue
  // until the given task is done.
  impl_->WorkOnTasks(absl::Condition(task, &Task::done));
  impl_->Await(absl::Condition(task, &Task::done));
}

void SlinkyThreadPool::wait_for(predicate_ref condition) {
  impl_->WorkOnTasks(absl::Condition(&condition));
  impl_->Await(absl::Condition(&condition));
}

void SlinkyThreadPool::atomic_call(slinky::function_ref<void()> t) {
  impl_->AtomicCall(t);
}

int SlinkyThreadPool::thread_count() const { return impl_->thread_count(); }

}  // namespace xla::cpu
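The clamping at the top of enqueue() is simple but worth a worked example (the NumPartitions helper below is hypothetical; its body is the same two lines as in enqueue()):

#include <algorithm>
#include <cassert>
#include <cstddef>

size_t NumPartitions(size_t n, size_t thread_count) {
  size_t num_partitions = std::min<size_t>(n, thread_count);
  return std::max<size_t>(1, num_partitions);
}

int main() {
  assert(NumPartitions(10000, 4) == 4);  // capped by the thread count
  assert(NumPartitions(3, 8) == 3);      // capped by the work item count
  assert(NumPartitions(5, 0) == 1);      // no scheduler: run inline
  // With 4 partitions the caller works on one itself, so ScheduleWorkers()
  // is asked for num_partitions - 1 = 3 extra workers.
}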
third_party/xla/xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h (vendored, new file, 61 lines)

@@ -0,0 +1,61 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_SLINKY_THREADPOOL_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_SLINKY_THREADPOOL_H_

#include <cstddef>
#include <cstdint>

#include "slinky/base/function_ref.h"
#include "slinky/base/ref_count.h"
#include "slinky/base/thread_pool.h"

namespace Eigen {
struct ThreadPoolDevice;
class ThreadPoolInterface;
}  // namespace Eigen

namespace xla::cpu {

// An implementation of slinky::thread_pool that uses absl::Mutex for
// synchronization and dispatches work to Eigen::ThreadPoolInterface.
class SlinkyThreadPool final : public slinky::thread_pool {
 public:
  explicit SlinkyThreadPool(Eigen::ThreadPoolDevice* device);
  explicit SlinkyThreadPool(Eigen::ThreadPoolInterface* threadpool);
  ~SlinkyThreadPool() final;

  SlinkyThreadPool(SlinkyThreadPool&&) = default;
  SlinkyThreadPool& operator=(SlinkyThreadPool&&) = default;

  slinky::ref_count<task> enqueue(size_t n, task_body t,
                                  int32_t max_workers) final;

  void wait_for(task* t) final;
  void wait_for(predicate_ref condition) final;

  void atomic_call(slinky::function_ref<void()> t) final;

  int thread_count() const final;

 private:
  class Impl;
  slinky::ref_count<Impl> impl_;
};

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_SLINKY_THREADPOOL_H_
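The tests in the next hunk exercise this class through parallel_for. A condensed usage sketch based on those tests (illustrative only; parallel_for comes from the slinky::thread_pool base class and presumably enqueues a task over n work items that the caller then joins via wait_for):

#include <cstddef>
#include <vector>

#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"

int main() {
  // Back the pool with a 4-thread TSL thread pool, as the tests do.
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "example", 4);
  xla::cpu::SlinkyThreadPool pool(threads.AsEigenThreadPool());

  constexpr size_t n = 1024;
  std::vector<int> out(n, 0);

  // Run n work items across the pool; the caller participates too.
  pool.parallel_for(n, [&](size_t i) { out[i] = 1; });
}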
@@ -13,95 +13,81 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
+#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
 
 #include <array>
 #include <atomic>
 #include <cstddef>
 #include <cstdint>
 #include <vector>
 
 #include <gtest/gtest.h>
 #include "slinky/base/thread_pool.h"
 #include "xla/tsl/platform/env.h"
 #include "xla/tsl/platform/threadpool.h"
 
 namespace Eigen {
 class ThreadPoolInterface;
 }  // namespace Eigen
 
 namespace xla::cpu {
 
-TEST(YnnThreadpoolImpl, inline_scheduling) {
-  auto ynn_threadpool =
-      CreateYnnThreadpool(static_cast<Eigen::ThreadPoolInterface*>(nullptr));
-  auto thread_pool =
-      reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
+TEST(SlinkyThreadPoolTest, InlineScheduling) {
+  SlinkyThreadPool thread_pool(
+      static_cast<Eigen::ThreadPoolInterface*>(nullptr));
 
   static constexpr size_t size = 10000;
 
   std::vector<int32_t> data(size, 0);
   auto inc = [&](size_t i) { data[i]++; };
 
-  thread_pool->parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
 
   std::vector<int32_t> expected(size, 1);
   EXPECT_EQ(data, expected);
 }
 
-TEST(YnnThreadpoolImpl, single_loop) {
+TEST(SlinkyThreadPoolTest, SingleLoop) {
   tsl::thread::ThreadPool test_thread_pool(tsl::Env::Default(), "test", 4);
-  auto ynn_threadpool =
-      CreateYnnThreadpool(test_thread_pool.AsEigenThreadPool());
-  auto thread_pool =
-      reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
+  SlinkyThreadPool thread_pool(test_thread_pool.AsEigenThreadPool());
 
   static constexpr size_t size = 10000;
 
   std::vector<int32_t> data(size, 0);
   auto inc = [&](size_t i) { data[i]++; };
 
-  thread_pool->parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
 
   std::vector<int32_t> expected(size, 1);
   EXPECT_EQ(data, expected);
 }
 
-TEST(YnnThreadpoolImpl, loop_chain) {
+TEST(SlinkyThreadPoolTest, LoopChain) {
   tsl::thread::ThreadPool test_thread_pool(tsl::Env::Default(), "test", 4);
-  auto ynn_threadpool =
-      CreateYnnThreadpool(test_thread_pool.AsEigenThreadPool());
-  auto thread_pool =
-      reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
+  SlinkyThreadPool thread_pool(test_thread_pool.AsEigenThreadPool());
 
   static constexpr size_t size = 10000;
 
   std::vector<int32_t> data(size, 0);
   auto inc = [&](size_t i) { data[i]++; };
 
-  thread_pool->parallel_for(size, inc);
-  thread_pool->parallel_for(size, inc);
-  thread_pool->parallel_for(size, inc);
-  thread_pool->parallel_for(size, inc);
-  thread_pool->parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
+  thread_pool.parallel_for(size, inc);
 
   std::vector<int32_t> expected(size, 5);
   EXPECT_EQ(data, expected);
 }
 
-TEST(YnnThreadpoolImpl, nested_loops) {
+TEST(SlinkyThreadPoolTest, NestedLoops) {
   tsl::thread::ThreadPool test_thread_pool(tsl::Env::Default(), "test", 4);
-  auto ynn_threadpool =
-      CreateYnnThreadpool(test_thread_pool.AsEigenThreadPool());
-  auto thread_pool =
-      reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
+  SlinkyThreadPool thread_pool(test_thread_pool.AsEigenThreadPool());
 
   static constexpr size_t size = 100;
 
   std::array<std::atomic<int32_t>, size> data = {{0}};
   auto inc = [&](size_t i) { data[i]++; };
 
-  thread_pool->parallel_for(
-      size, [&](size_t i) { thread_pool->parallel_for(size, inc); });
+  thread_pool.parallel_for(
+      size, [&](size_t i) { thread_pool.parallel_for(size, inc); });
 
   for (size_t i = 0; i < size; ++i) {
     EXPECT_EQ(data[i], size);
@@ -13,29 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <atomic>
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <optional>
#include <utility>
#include <vector>

#include "ynnpack/include/ynnpack.h"
#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/fixed_array.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/synchronization/mutex.h"
#include "slinky/base/function_ref.h"
#include "slinky/base/ref_count.h"
#include "slinky/base/thread_pool.h"
#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"

#define EIGEN_USE_THREADS
#include "Eigen/ThreadPool"
@@ -43,550 +29,11 @@ limitations under the License.

namespace xla::cpu {

namespace {

// An implementation of slinky::thread_pool that uses absl::Mutex for
// synchronization and dispatches work to Eigen::ThreadPoolInterface.
class YnnThreadpoolImpl final : public slinky::thread_pool {
 public:
  explicit YnnThreadpoolImpl(Eigen::ThreadPoolDevice* device);
  explicit YnnThreadpoolImpl(Eigen::ThreadPoolInterface* threadpool);
  ~YnnThreadpoolImpl() final;

  YnnThreadpoolImpl(YnnThreadpoolImpl&&) = delete;
  YnnThreadpoolImpl& operator=(YnnThreadpoolImpl&&) = delete;

  slinky::ref_count<task> enqueue(size_t n, task_body t,
                                  int32_t max_workers) final;

  void wait_for(task* t) final;
  void wait_for(predicate_ref condition) final;

  void atomic_call(slinky::function_ref<void()> t) final;

  int thread_count() const final;

 private:
  class impl;
  slinky::ref_count<impl> impl_;
};
//===----------------------------------------------------------------------===//
// work_queue
//===----------------------------------------------------------------------===//

// Forward declare.
class worker;

// A work queue that partitions `num_work_items` work items into
// `num_partitions` partitions processed by parallel workers.
class work_queue {
 public:
  work_queue(size_t num_work_items, size_t num_partitions);

  // Returns the next work item in the given partition. Returns std::nullopt
  // if the partition is complete.
  std::optional<size_t> pop_work_item(size_t partition_index);

  // Returns the partition's [begin, end) work item range.
  std::pair<size_t, size_t> partition_range(size_t partition_index) const;

  size_t num_partitions() const { return partitions_.size(); }

  // If the work queue is empty, it means that all work items are being
  // processed by the workers, and the task will be done once all workers
  // complete.
  bool is_empty() const { return empty_.load(std::memory_order_relaxed); }

 private:
  friend class worker;

  // A work-item partition tracking the next work item to process.
  struct partition {
    void initialize(size_t begin, size_t end);

    // Tracks the index of the next work item in the assigned partition.
    ABSL_CACHELINE_ALIGNED std::atomic<size_t> index;
    size_t begin;
    size_t end;
  };

  void set_empty() { empty_.store(true, std::memory_order_relaxed); }

  absl::FixedArray<partition, 32> partitions_;
  ABSL_CACHELINE_ALIGNED std::atomic<bool> empty_;
};

}  // namespace

void work_queue::partition::initialize(size_t begin, size_t end) {
  index.store(begin, std::memory_order_relaxed);
  this->begin = begin;
  this->end = end;
}

work_queue::work_queue(size_t num_work_items, size_t num_partitions)
    : partitions_(num_partitions), empty_(num_work_items == 0) {
  size_t partition_size = num_work_items / num_partitions;
  size_t rem_work = num_work_items % num_partitions;
  for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
    end = begin + partition_size + ((i < rem_work) ? 1 : 0);
    partitions_[i].initialize(begin, end);
  }
}

std::optional<size_t> work_queue::pop_work_item(size_t partition_index) {
  DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
  partition& partition = partitions_.data()[partition_index];

  // Check if the partition is already empty.
  if (size_t index = partition.index.load(std::memory_order_relaxed);
      ABSL_PREDICT_FALSE(index >= partition.end)) {
    return std::nullopt;
  }

  // Try to acquire the next work item in the partition.
  size_t index = partition.index.fetch_add(1, std::memory_order_relaxed);
  return ABSL_PREDICT_FALSE(index >= partition.end) ? std::nullopt
                                                    : std::make_optional(index);
}

std::pair<size_t, size_t> work_queue::partition_range(
    size_t partition_index) const {
  DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
  return {partitions_[partition_index].begin, partitions_[partition_index].end};
}
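The constructor splits num_work_items as evenly as possible, giving the first rem_work partitions one extra item each. The same arithmetic replayed with concrete numbers (illustrative standalone program):

#include <cstddef>
#include <cstdio>

int main() {
  size_t num_work_items = 10, num_partitions = 4;
  size_t partition_size = num_work_items / num_partitions;  // 2
  size_t rem_work = num_work_items % num_partitions;        // 2
  for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
    end = begin + partition_size + ((i < rem_work) ? 1 : 0);
    std::printf("partition %zu: [%zu, %zu)\n", i, begin, end);
  }
  // Prints: partition 0: [0, 3)   partition 1: [3, 6)
  //         partition 2: [6, 8)   partition 3: [8, 10)
}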
//===----------------------------------------------------------------------===//
// worker
//===----------------------------------------------------------------------===//

namespace {

// A worker processes work items from the work queue starting from the
// assigned work partition. Once the assigned partition is complete, it tries
// to pop a work item from the next partition. Once the work queue is empty
// (the worker wraps around to the initial partition), it returns an empty
// work item.
class worker {
 public:
  worker(size_t partition_index, work_queue* queue);

  std::optional<size_t> pop_work_item();

 private:
  size_t initial_partition_index_;
  size_t partition_index_;
  work_queue* queue_;
};

}  // namespace

worker::worker(size_t partition_index, work_queue* queue)
    : initial_partition_index_(partition_index),
      partition_index_(partition_index),
      queue_(queue) {}

std::optional<size_t> worker::pop_work_item() {
  std::optional<size_t> work = queue_->pop_work_item(partition_index_);
  if (ABSL_PREDICT_TRUE(work)) {
    return work;
  }

  while (!work.has_value() && !queue_->is_empty()) {
    // Wrap around to the first partition.
    if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
      partition_index_ = 0;
    }

    // We checked all partitions and got back to the partition we started from.
    if (ABSL_PREDICT_FALSE(partition_index_ == initial_partition_index_)) {
      queue_->set_empty();
      break;
    }

    work = queue_->pop_work_item(partition_index_);
  }

  return work;
}
//===----------------------------------------------------------------------===//
// task_impl
//===----------------------------------------------------------------------===//

namespace {

// Running a task can result in three states:
//
// kPending: The task is still being processed by the worker threads.
// kComplete: The caller thread is the one who completed the task.
// kDone: The task is done and all work items have been processed, however
//   the caller thread didn't process any work items.
//
// We need this state to signal the waiter thread just once, from the thread
// that completed the task.
enum class task_state { kPending, kComplete, kDone };

class task_impl final : public YnnThreadpoolImpl::task {
 public:
  task_impl(YnnThreadpoolImpl::task_body body, size_t num_work_items,
            size_t num_partitions);

  // Runs this task by processing work items in the current thread.
  task_state run();

  int64_t num_workers() const;
  bool is_empty_work_queue() const;
  bool done() const final;

 private:
  YnnThreadpoolImpl::task_body body_;
  work_queue work_queue_;

  ABSL_CACHELINE_ALIGNED std::atomic<size_t> worker_index_;
  ABSL_CACHELINE_ALIGNED std::atomic<size_t> pending_work_items_;
};

task_impl::task_impl(YnnThreadpoolImpl::task_body body, size_t num_work_items,
                     size_t num_partitions)
    : body_(std::move(body)),
      work_queue_(num_work_items, num_partitions),
      worker_index_(0),
      pending_work_items_(num_work_items) {}

task_state task_impl::run() {
  // If we have more workers joining the task than the number of partitions,
  // then we have to wrap around to the first partition.
  size_t worker_index = worker_index_.fetch_add(1, std::memory_order_relaxed);
  if (ABSL_PREDICT_FALSE(worker_index >= work_queue_.num_partitions())) {
    worker_index %= work_queue_.num_partitions();
  }

  // Each worker processes the body using its own copy of the task body.
  worker w(worker_index, &work_queue_);
  size_t num_processed_work_items = 0;

  if (std::optional<size_t> item = w.pop_work_item()) {
    YnnThreadpoolImpl::task_body body = body_;

    do {
      body(*item);
      ++num_processed_work_items;
    } while ((item = w.pop_work_item()).has_value());
  }

  // The number of pending work items should never go below zero.
  size_t previous_work_items = pending_work_items_.fetch_sub(
      num_processed_work_items, std::memory_order_acq_rel);
  DCHECK_GE(previous_work_items, num_processed_work_items);

  // Task is done if we have no more work items to process. Task is complete if
  // we are the one who processed the last work item.
  bool is_done = previous_work_items == num_processed_work_items;
  bool is_complete = is_done && num_processed_work_items > 0;

  return is_complete ? task_state::kComplete
         : is_done   ? task_state::kDone
                     : task_state::kPending;
}

int64_t task_impl::num_workers() const {
  return worker_index_.load(std::memory_order_relaxed);
}

bool task_impl::is_empty_work_queue() const { return work_queue_.is_empty(); }

bool task_impl::done() const {
  return pending_work_items_.load(std::memory_order_acquire) == 0;
}
//===----------------------------------------------------------------------===//
// YnnThreadpoolImpl::impl
//===----------------------------------------------------------------------===//

// We keep a stack of tasks that are currently being processed by the current
// thread, to avoid recursive calls.
static thread_local std::vector<const task_impl*> task_stack;  // NOLINT

class YnnThreadpoolImpl::impl : public slinky::ref_counted<impl> {
 public:
  explicit impl(Eigen::ThreadPoolInterface* threadpool);

  // Work on the single task and return the state of the task.
  task_state work_on_task(task_impl* task);

  // Work on all tasks in the queue. Returns when we run out of tasks to
  // process.
  void work_on_tasks(const absl::Condition& condition);

  // Enqueues a new task into the queue and returns a reference to it.
  slinky::ref_count<task_impl> enqueue(YnnThreadpoolImpl::task_body body,
                                       size_t num_work_items,
                                       size_t num_partitions);

  void await(const absl::Condition& condition);

  void atomic_call(slinky::function_ref<void()> t);

  // Returns true if we can schedule more workers into the underlying scheduler.
  bool can_schedule_workers() const;

  // Schedules the given number of workers for the given task. Worker scheduling
  // uses recursive work splitting and early exit if the task does not need any
  // more workers, or if we reached the maximum number of scheduled workers.
  void schedule_workers(int64_t num_workers, slinky::ref_count<task_impl> task);

  size_t thread_count() const { return thread_count_; }

 private:
  friend class slinky::ref_counted<impl>;
  static void destroy(impl* ptr) { delete ptr; }

  // A state of the work scheduling for a given task.
  struct schedule_state : public slinky::ref_counted<schedule_state> {
    schedule_state(int64_t remaining_workers, slinky::ref_count<task_impl> task,
                   slinky::ref_count<impl> impl)
        : remaining_workers(remaining_workers),
          task(std::move(task)),
          impl(std::move(impl)) {}

    static void destroy(schedule_state* ptr) { delete ptr; }

    std::atomic<int64_t> remaining_workers;
    slinky::ref_count<task_impl> task;
    slinky::ref_count<impl> impl;
  };

  // Worker scheduling function for the underlying scheduler.
  template <bool release_impl_ref>
  static void schedule_workers(schedule_state* context);

  // Dequeues a pending task from the queue.
  slinky::ref_count<task_impl> dequeue();

  // Signals all waiter threads waiting on the waiter mutex.
  void signal_waiters();

  Eigen::ThreadPoolInterface* threadpool_;
  size_t thread_count_;

  std::deque<slinky::ref_count<task_impl>> tasks_ ABSL_GUARDED_BY(tasks_mutex_);

  // A mutex for guarding mutable state accessed concurrently.
  ABSL_CACHELINE_ALIGNED absl::Mutex tasks_mutex_;

  // A mutex for signalling threads waiting on the tasks or conditions.
  ABSL_CACHELINE_ALIGNED absl::Mutex waiter_mutex_;
};
YnnThreadpoolImpl::impl::impl(Eigen::ThreadPoolInterface* threadpool)
    : threadpool_(threadpool),
      thread_count_(threadpool_ ? threadpool_->NumThreads() : 0) {}

slinky::ref_count<task_impl> YnnThreadpoolImpl::impl::enqueue(
    YnnThreadpoolImpl::task_body body, size_t num_work_items,
    size_t num_partitions) {
  slinky::ref_count<task_impl> task(
      new task_impl(std::move(body), num_work_items, num_partitions));

  absl::MutexLock lock(tasks_mutex_);
  return tasks_.emplace_back(std::move(task));
}

slinky::ref_count<task_impl> YnnThreadpoolImpl::impl::dequeue() {
  absl::MutexLock lock(tasks_mutex_);

  for (auto i = tasks_.begin(); i != tasks_.end();) {
    slinky::ref_count<task_impl>& task = *i;

    // Task doesn't have any more work items to process.
    if (ABSL_PREDICT_FALSE(task->is_empty_work_queue())) {
      i = tasks_.erase(i);
      continue;
    }

    // Don't run the same task multiple times on the same thread.
    if (ABSL_PREDICT_FALSE(absl::c_contains(task_stack, &*task))) {
      ++i;
      continue;
    }

    return task;
  }

  return nullptr;
}

task_state YnnThreadpoolImpl::impl::work_on_task(task_impl* task) {
  DCHECK(absl::c_find(task_stack, task) == task_stack.end());

  task_stack.push_back(task);
  task_state state = task->run();
  task_stack.pop_back();

  // If we are the one who completed the task, we signal the waiters to wake up
  // any threads that are waiting for the task completion. If the task was
  // completed by another worker, we do nothing to avoid the cost of waking up
  // the same thread multiple times.
  if (ABSL_PREDICT_FALSE(state == task_state::kComplete)) {
    signal_waiters();
  }

  return state;
}

void YnnThreadpoolImpl::impl::work_on_tasks(const absl::Condition& condition) {
  while (slinky::ref_count<task_impl> task = dequeue()) {
    work_on_task(&*task);

    if (ABSL_PREDICT_TRUE(condition.Eval())) {
      return;
    }
  }
}

void YnnThreadpoolImpl::impl::await(const absl::Condition& condition) {
  if (ABSL_PREDICT_FALSE(!condition.Eval())) {
    absl::MutexLock lock(waiter_mutex_);
    waiter_mutex_.Await(condition);
  }
}

void YnnThreadpoolImpl::impl::signal_waiters() {
  absl::MutexLock lock(waiter_mutex_);
}

void YnnThreadpoolImpl::impl::atomic_call(slinky::function_ref<void()> t) {
  absl::MutexLock lock(waiter_mutex_);
  t();
}

bool YnnThreadpoolImpl::impl::can_schedule_workers() const {
  // One reference is owned by the parent YnnThreadpoolImpl, every other
  // reference is owned by a worker scheduled into the underlying scheduler.
  return ref_count() < 1 + thread_count();
}

void YnnThreadpoolImpl::impl::schedule_workers(
    int64_t num_workers, slinky::ref_count<task_impl> task) {
  if (ABSL_PREDICT_TRUE(num_workers > 0 && can_schedule_workers())) {
    slinky::ref_count<schedule_state> state(
        new schedule_state(num_workers - 1, std::move(task), {this}));
    threadpool_->Schedule([state = state.take()]() {
      schedule_workers</*release_impl_ref=*/false>(state);
    });
  }
}
template <bool release_impl_ref>
void YnnThreadpoolImpl::impl::schedule_workers(schedule_state* context) {
  auto state = slinky::ref_count<schedule_state>::assume(context);

  // We recursively keep scheduling workers into the underlying scheduler.
  // This is more efficient than scheduling them sequentially from a single
  // thread, because workers can start processing the task sooner and we
  // distribute thread wake-ups evenly across underlying threads.
  static constexpr int32_t kNumRecursiveWorkers = 2;

  for (size_t i = 0; i < kNumRecursiveWorkers; ++i) {
    bool schedule_worker =
        state->impl->can_schedule_workers() &&
        !state->task->is_empty_work_queue() &&
        state->remaining_workers.fetch_sub(1, std::memory_order_relaxed) > 0;

    if (ABSL_PREDICT_TRUE(!schedule_worker)) {
      break;
    }

    // Add +1 reference to account for the scheduled worker, as we use the
    // `impl` reference count to track the number of active workers.
    state->impl->add_ref();
    state->impl->threadpool_->Schedule(
        [state = slinky::ref_count<schedule_state>(state).take()]() {
          YnnThreadpoolImpl::impl::schedule_workers</*release_impl_ref=*/true>(
              state);
        });
  }

  // Keep processing tasks from the queue until we are out of tasks.
  static constexpr bool kFalse = false;
  state->impl->work_on_tasks(absl::Condition(&kFalse));

  // One `impl` reference is implicitly owned by the `state`, every additional
  // reference is added and released explicitly by the worker task.
  if constexpr (release_impl_ref) {
    state->impl->release();
  }
}
//===----------------------------------------------------------------------===//
// YnnThreadpoolImpl
//===----------------------------------------------------------------------===//

YnnThreadpoolImpl::YnnThreadpoolImpl(Eigen::ThreadPoolDevice* device)
    : impl_(new impl(device ? device->getPool() : nullptr)) {}

YnnThreadpoolImpl::YnnThreadpoolImpl(Eigen::ThreadPoolInterface* threadpool)
    : impl_(new impl(threadpool)) {}

YnnThreadpoolImpl::~YnnThreadpoolImpl() = default;

slinky::ref_count<YnnThreadpoolImpl::task> YnnThreadpoolImpl::enqueue(
    size_t n, task_body t, int32_t max_workers) {
  CHECK_GE(max_workers, n);

  // Don't create more partitions than the number of threads. Also make sure
  // that we have at least one partition (if we don't have a scheduler).
  size_t num_partitions = std::min<size_t>(n, thread_count());
  num_partitions = std::max<size_t>(1, num_partitions);

  auto task = impl_->enqueue(std::move(t), n, num_partitions);

  // If we don't have any worker threads, we return the task to the caller and
  // assume that the caller will wait on it.
  if (ABSL_PREDICT_FALSE(impl_->thread_count() == 0)) {
    return task;
  }

  // We assume that the caller will immediately start working on the task, so
  // we need to schedule workers only for the remaining number of partitions.
  impl_->schedule_workers(/*num_workers=*/num_partitions - 1, task);

  return task;
}

void YnnThreadpoolImpl::wait_for(task* t) {
  task_impl* task = static_cast<task_impl*>(t);
  task_state state = impl_->work_on_task(task);

  // If the task is complete or done, we are immediately done with waiting.
  if (ABSL_PREDICT_TRUE(state == task_state::kComplete ||
                        state == task_state::kDone)) {
    return;
  }

  // Switch to the work stealing mode and work on other tasks in the queue
  // until the given task is done.
  impl_->work_on_tasks(absl::Condition(task, &task_impl::done));
  impl_->await(absl::Condition(task, &task_impl::done));
}

void YnnThreadpoolImpl::wait_for(predicate_ref condition) {
  impl_->work_on_tasks(absl::Condition(&condition));
  impl_->await(absl::Condition(&condition));
}

void YnnThreadpoolImpl::atomic_call(slinky::function_ref<void()> t) {
  impl_->atomic_call(t);
}

int YnnThreadpoolImpl::thread_count() const { return impl_->thread_count(); }

}  // namespace

absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    Eigen::ThreadPoolInterface* threadpool) {
  return CreateYnnThreadpool([&](ynn_threadpool_t* ynn_threadpool) {
    *ynn_threadpool =
-        reinterpret_cast<ynn_threadpool_t>(new YnnThreadpoolImpl(threadpool));
+        reinterpret_cast<ynn_threadpool_t>(new SlinkyThreadPool(threadpool));
    return ynn_status_success;
  });
}