[xla:cpu:ynn] Implement SlinkyThreadPool on top of WorkQueue and Worker APIs

Remove `work_queue` and `worker` that were originally forked from `xla::cpu::WorkQueue` and `Worker`

PiperOrigin-RevId: 823179793
Eugene Zhulenev 2025-10-23 14:04:30 -07:00 committed by TensorFlower Gardener
parent 96c1b6c0a6
commit 7ba3317857
6 changed files with 535 additions and 605 deletions
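For context, the new SlinkyThreadPool builds each parallel task on the shared xla::cpu::WorkQueue and Worker classes instead of the forked work_queue/worker removed below. A minimal sketch of that pattern, assuming only the WorkQueue/Worker surface visible in this diff (a constructor taking work-item and partition counts, and Worker::Pop() returning std::optional<size_t>); DrainPartition is a hypothetical helper, not part of the commit:

#include <cstddef>
#include <optional>

#include "xla/backends/cpu/runtime/work_queue.h"

// Hypothetical helper: processes work items with `body`, starting from the
// partition assigned to this worker and stealing from the remaining
// partitions once it is drained. This mirrors what Task::Run does in
// slinky_threadpool.cc below.
template <typename Body>
size_t DrainPartition(xla::cpu::WorkQueue& queue, size_t partition_index,
                      Body&& body) {
  xla::cpu::Worker worker(partition_index, &queue);
  size_t num_processed = 0;
  while (std::optional<size_t> item = worker.Pop()) {
    body(*item);
    ++num_processed;
  }
  return num_processed;
}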


@ -55,6 +55,8 @@ class WorkQueue {
size_t num_partitions() const { return partitions_.size(); }
bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }
private:
friend class Worker;
@ -67,9 +69,8 @@ class WorkQueue {
size_t end;
};
// An empty work queue flag to stop worker threads from looping through all
// partitions looking for work.
bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }
// Sets an empty work queue flag to stop worker threads from looping through
// all partitions looking for work.
void SetEmpty() { empty_.store(true, std::memory_order_relaxed); }
// Notify that one of the workers switched to the work stealing mode.
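The empty flag above exists so that workers in work-stealing mode stop scanning partitions once everything has been handed out. A simplified, self-contained model of that contract (not the XLA WorkQueue itself; EmptyFlagModel and try_pop are illustrative stand-ins):

#include <atomic>
#include <cstddef>
#include <optional>

// Illustrative model: a stealing worker scans all partitions once and
// publishes the empty flag after a full unsuccessful wrap-around, so later
// scans bail out immediately instead of walking every partition again.
struct EmptyFlagModel {
  std::atomic<bool> empty{false};

  bool IsEmpty() const { return empty.load(std::memory_order_relaxed); }
  void SetEmpty() { empty.store(true, std::memory_order_relaxed); }

  // `try_pop(partition)` stands in for popping one item from a partition.
  template <typename TryPop>
  std::optional<size_t> Steal(size_t start, size_t num_partitions,
                              TryPop&& try_pop) {
    for (size_t i = 0; !IsEmpty() && i < num_partitions; ++i) {
      if (std::optional<size_t> item = try_pop((start + i) % num_partitions)) {
        return item;
      }
    }
    if (!IsEmpty()) SetEmpty();  // A full wrap-around found no work left.
    return std::nullopt;
  }
};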


@ -14,6 +14,35 @@ package_group(
],
)
cc_library(
name = "slinky_threadpool",
srcs = ["slinky_threadpool.cc"],
hdrs = ["slinky_threadpool.h"],
deps = [
"//xla/backends/cpu/runtime:work_queue",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3",
"@slinky//slinky/base:thread_pool",
],
)
ynn_cc_test(
name = "slinky_threadpool_test",
srcs = ["slinky_threadpool_test.cc"],
deps = [
":slinky_threadpool",
"//xla/tsl/platform:env",
"@com_google_googletest//:gtest_main",
"@slinky//slinky/base:thread_pool",
],
)
cc_library(
name = "ynn_interop",
srcs = ["ynn_interop.cc"],
@ -35,6 +64,7 @@ cc_library(
srcs = ["ynn_threadpool.cc"],
hdrs = ["ynn_threadpool.h"],
deps = [
":slinky_threadpool",
":ynn_interop",
"@XNNPACK//ynnpack:ynnpack_h",
"@com_google_absl//absl/algorithm:container",
@ -49,17 +79,6 @@ cc_library(
],
)
ynn_cc_test(
name = "ynn_threadpool_test",
srcs = ["ynn_threadpool_test.cc"],
deps = [
":ynn_threadpool",
"//xla/tsl/platform:env",
"@com_google_googletest//:gtest_main",
"@slinky//slinky/base:thread_pool",
],
)
cc_library(
name = "ynn_fusion_thunk",
srcs = ["ynn_fusion_thunk.cc"],


@ -0,0 +1,416 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <optional>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/log/check.h"
#include "absl/synchronization/mutex.h"
#include "slinky/base/function_ref.h"
#include "slinky/base/ref_count.h"
#include "slinky/base/thread_pool.h"
#include "xla/backends/cpu/runtime/work_queue.h"
#define EIGEN_USE_THREADS
#include "Eigen/ThreadPool"
#include "unsupported/Eigen/CXX11/Tensor"
namespace xla::cpu {
//===----------------------------------------------------------------------===//
// Task
//===----------------------------------------------------------------------===//
namespace {
// Running a task can result in three states:
//
// kPending: The task is still being processed by the worker threads.
// kComplete: The caller thread is the one who completed the task.
// kDone: The task is done and all work items have been processed; however,
// the caller thread didn't process any work items.
//
// We need this state to signal the waiter thread just once, from a thread that
// completed the task.
enum class TaskState { kPending, kComplete, kDone };
class Task final : public SlinkyThreadPool::task {
public:
Task(SlinkyThreadPool::task_body body, size_t num_work_items,
size_t num_partitions);
// Runs this task by processing work items in the current thread.
TaskState Run();
// Returns the number of workers that are currently working on this task.
int64_t num_workers() const;
bool is_empty_work_queue() const;
bool done() const final;
private:
SlinkyThreadPool::task_body body_;
WorkQueue work_queue_;
ABSL_CACHELINE_ALIGNED std::atomic<size_t> worker_index_;
ABSL_CACHELINE_ALIGNED std::atomic<size_t> pending_work_items_;
};
} // namespace
Task::Task(SlinkyThreadPool::task_body body, size_t num_work_items,
size_t num_partitions)
: body_(std::move(body)),
work_queue_(num_work_items, num_partitions),
worker_index_(0),
pending_work_items_(num_work_items) {}
TaskState Task::Run() {
// If we have more workers joining the task than the number of partitions,
// then we have to wrap around to the first partition.
size_t worker_index = worker_index_.fetch_add(1, std::memory_order_relaxed);
if (ABSL_PREDICT_FALSE(worker_index >= work_queue_.num_partitions())) {
worker_index %= work_queue_.num_partitions();
}
// Each worker processes the body using its own copy of the task.
Worker w(worker_index, &work_queue_);
size_t num_processed_work_items = 0;
if (std::optional<size_t> item = w.Pop()) {
SlinkyThreadPool::task_body body = body_;
do {
body(*item);
++num_processed_work_items;
} while ((item = w.Pop()).has_value());
}
// The number of pending work items should never go below zero.
size_t previous_work_items = pending_work_items_.fetch_sub(
num_processed_work_items, std::memory_order_acq_rel);
DCHECK_GE(previous_work_items, num_processed_work_items);
// Task is done if we have no more work items to process. Task is complete if
// we are the one who processed the last work item.
bool is_done = previous_work_items == num_processed_work_items;
bool is_complete = is_done && num_processed_work_items > 0;
return is_complete ? TaskState::kComplete
: is_done ? TaskState::kDone
: TaskState::kPending;
}
int64_t Task::num_workers() const {
return worker_index_.load(std::memory_order_relaxed);
}
bool Task::is_empty_work_queue() const { return work_queue_.IsEmpty(); }
bool Task::done() const {
return pending_work_items_.load(std::memory_order_acquire) == 0;
}
//===----------------------------------------------------------------------===//
// SlinkyThreadPool::Impl
//===----------------------------------------------------------------------===//
// We keep a stack of tasks that are currently being processed by the current
// thread, to avoid recursive calls.
static thread_local std::vector<const Task*> task_stack; // NOLINT
class SlinkyThreadPool::Impl : public slinky::ref_counted<Impl> {
public:
explicit Impl(Eigen::ThreadPoolInterface* threadpool);
// Enqueues a new task into the queue and returns a reference to it.
slinky::ref_count<Task> Enqueue(SlinkyThreadPool::task_body body,
size_t num_work_items, size_t num_partitions);
// Work on the single task and return the state of the task.
TaskState WorkOnTask(Task* task);
// Work on all tasks in the queue. Returns when we run out of tasks to process.
void WorkOnTasks(const absl::Condition& condition);
void Await(const absl::Condition& condition);
void AtomicCall(slinky::function_ref<void()> t);
// Returns true if we can schedule more workers into the underlying scheduler.
bool CanScheduleWorkers() const;
// Schedules the given number of workers for the given task. Worker scheduling
// uses recursive work splitting and early exit if the task does not need any
// more workers, or if we have reached the maximum number of scheduled workers.
void ScheduleWorkers(int64_t num_workers, slinky::ref_count<Task> task);
size_t thread_count() const { return thread_count_; }
private:
friend class slinky::ref_counted<Impl>;
static void destroy(Impl* ptr) { delete ptr; }
// A state of the work scheduling for a given task.
struct ScheduleState : public slinky::ref_counted<ScheduleState> {
ScheduleState(int64_t remaining_workers, slinky::ref_count<Task> task,
slinky::ref_count<Impl> impl)
: remaining_workers(remaining_workers),
task(std::move(task)),
impl(std::move(impl)) {}
static void destroy(ScheduleState* ptr) { delete ptr; }
std::atomic<int64_t> remaining_workers;
slinky::ref_count<Task> task;
slinky::ref_count<Impl> impl;
};
// Worker scheduling function for the underlying scheduler.
template <bool release_impl_ref>
static void ScheduleWorkers(ScheduleState* context);
// Dequeues a pending task from the queue.
slinky::ref_count<Task> Dequeue();
// Signals all waiter threads waiting on the waiter mutex.
void SignalWaiters();
Eigen::ThreadPoolInterface* threadpool_;
size_t thread_count_;
std::deque<slinky::ref_count<Task>> tasks_ ABSL_GUARDED_BY(tasks_mutex_);
// A mutex for guarding mutable state accessed concurrently.
ABSL_CACHELINE_ALIGNED absl::Mutex tasks_mutex_;
// A mutex for signalling threads waiting on the tasks or conditions.
ABSL_CACHELINE_ALIGNED absl::Mutex waiter_mutex_;
};
SlinkyThreadPool::Impl::Impl(Eigen::ThreadPoolInterface* threadpool)
: threadpool_(threadpool),
thread_count_(threadpool_ ? threadpool_->NumThreads() : 0) {}
slinky::ref_count<Task> SlinkyThreadPool::Impl::Enqueue(
SlinkyThreadPool::task_body body, size_t num_work_items,
size_t num_partitions) {
slinky::ref_count<Task> task(
new Task(std::move(body), num_work_items, num_partitions));
absl::MutexLock lock(tasks_mutex_);
return tasks_.emplace_back(std::move(task));
}
slinky::ref_count<Task> SlinkyThreadPool::Impl::Dequeue() {
absl::MutexLock lock(tasks_mutex_);
for (auto i = tasks_.begin(); i != tasks_.end();) {
slinky::ref_count<Task>& task = *i;
// Task doesn't have any more work items to process.
if (ABSL_PREDICT_FALSE(task->is_empty_work_queue())) {
i = tasks_.erase(i);
continue;
}
// Don't run the same task multiple times on the same thread.
if (ABSL_PREDICT_FALSE(absl::c_contains(task_stack, &*task))) {
++i;
continue;
}
return task;
}
return nullptr;
}
TaskState SlinkyThreadPool::Impl::WorkOnTask(Task* task) {
DCHECK(absl::c_find(task_stack, task) == task_stack.end());
task_stack.push_back(task);
TaskState state = task->Run();
task_stack.pop_back();
// If we are the one who completed the task, we signal the waiters to wake up
// any threads that are waiting for the task completion. If the task was
// completed by another worker, we do nothing to avoid the cost of waking up
// the same thread multiple times.
if (ABSL_PREDICT_FALSE(state == TaskState::kComplete)) {
SignalWaiters();
}
return state;
}
void SlinkyThreadPool::Impl::WorkOnTasks(const absl::Condition& condition) {
while (slinky::ref_count<Task> task = Dequeue()) {
WorkOnTask(&*task);
if (ABSL_PREDICT_TRUE(condition.Eval())) {
return;
}
}
}
void SlinkyThreadPool::Impl::Await(const absl::Condition& condition) {
if (ABSL_PREDICT_FALSE(!condition.Eval())) {
absl::MutexLock lock(waiter_mutex_);
waiter_mutex_.Await(condition);
}
}
void SlinkyThreadPool::Impl::SignalWaiters() {
absl::MutexLock lock(waiter_mutex_);
}
void SlinkyThreadPool::Impl::AtomicCall(slinky::function_ref<void()> t) {
absl::MutexLock lock(waiter_mutex_);
t();
}
bool SlinkyThreadPool::Impl::CanScheduleWorkers() const {
// One reference is owned by the parent SlinkyThreadPool, every other
// reference is owned by a worker scheduled into the underlying scheduler.
return ref_count() < 1 + thread_count();
}
void SlinkyThreadPool::Impl::ScheduleWorkers(int64_t num_workers,
slinky::ref_count<Task> task) {
if (ABSL_PREDICT_TRUE(num_workers > 0 && CanScheduleWorkers())) {
slinky::ref_count<ScheduleState> state(
new ScheduleState(num_workers - 1, std::move(task), {this}));
threadpool_->Schedule([state = state.take()]() {
ScheduleWorkers</*release_impl_ref=*/false>(state);
});
}
}
template <bool release_impl_ref>
void SlinkyThreadPool::Impl::ScheduleWorkers(ScheduleState* context) {
auto state = slinky::ref_count<ScheduleState>::assume(context);
// We recursively keep scheduling workers into the underlying scheduler.
// This is more efficient than scheduling them sequentially from a single
// thread, because workers can start processing the task sooner and we
// distribute thread wake-ups evenly across underlying threads.
static constexpr int32_t kNumRecursiveWorkers = 2;
for (size_t i = 0; i < kNumRecursiveWorkers; ++i) {
bool schedule_worker =
state->impl->CanScheduleWorkers() &&
!state->task->is_empty_work_queue() &&
state->remaining_workers.fetch_sub(1, std::memory_order_relaxed) > 0;
if (ABSL_PREDICT_TRUE(!schedule_worker)) {
break;
}
// Add +1 reference to account for the scheduled worker, as we use `impl`
// reference count to track the number of active workers.
state->impl->add_ref();
state->impl->threadpool_->Schedule(
[state = slinky::ref_count<ScheduleState>(state).take()]() {
SlinkyThreadPool::Impl::ScheduleWorkers</*release_impl_ref=*/true>(
state);
});
}
// Keep processing tasks from the queue until we are out of tasks.
static constexpr bool kFalse = false;
state->impl->WorkOnTasks(absl::Condition(&kFalse));
// One `impl` reference implicitly owned by the `state`, every additional
// reference is added and released explicitly by the worker task.
if constexpr (release_impl_ref) {
state->impl->release();
}
}
//===----------------------------------------------------------------------===//
// SlinkyThreadPool
//===----------------------------------------------------------------------===//
SlinkyThreadPool::SlinkyThreadPool(Eigen::ThreadPoolDevice* device)
: impl_(new Impl(device ? device->getPool() : nullptr)) {}
SlinkyThreadPool::SlinkyThreadPool(Eigen::ThreadPoolInterface* threadpool)
: impl_(new Impl(threadpool)) {}
SlinkyThreadPool::~SlinkyThreadPool() = default;
slinky::ref_count<SlinkyThreadPool::task> SlinkyThreadPool::enqueue(
size_t n, task_body t, int32_t max_workers) {
CHECK_GE(max_workers, n);
// Don't create more partitions than the number of threads. Also make sure
// that we have at least one partition (if we don't have a scheduler).
size_t num_partitions = std::min<size_t>(n, thread_count());
num_partitions = std::max<size_t>(1, num_partitions);
auto task = impl_->Enqueue(std::move(t), n, num_partitions);
// If we don't have any worker threads, we return a task to the caller, and
// assume that the caller will wait on it.
if (ABSL_PREDICT_FALSE(impl_->thread_count() == 0)) {
return task;
}
// We assume that the caller will immediately start working on the task, so we
// need to schedule workers only for the remaining number of partitions.
impl_->ScheduleWorkers(/*num_workers=*/num_partitions - 1, task);
return task;
}
void SlinkyThreadPool::wait_for(task* t) {
Task* task = static_cast<Task*>(t);
TaskState state = impl_->WorkOnTask(task);
// If the task is complete or done, we are immediately done with waiting.
if (ABSL_PREDICT_TRUE(state == TaskState::kComplete ||
state == TaskState::kDone)) {
return;
}
// Switch to the work stealing mode and work on other tasks in the queue
// until the given task is done.
impl_->WorkOnTasks(absl::Condition(task, &Task::done));
impl_->Await(absl::Condition(task, &Task::done));
}
void SlinkyThreadPool::wait_for(predicate_ref condition) {
impl_->WorkOnTasks(absl::Condition(&condition));
impl_->Await(absl::Condition(&condition));
}
void SlinkyThreadPool::atomic_call(slinky::function_ref<void()> t) {
impl_->AtomicCall(t);
}
int SlinkyThreadPool::thread_count() const { return impl_->thread_count(); }
} // namespace xla::cpu
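The recursive ScheduleWorkers fan-out above is the subtle part: each scheduled closure wakes up to kNumRecursiveWorkers more workers before it starts draining tasks, so N workers are running after roughly log2(N) scheduling hops instead of N sequential Schedule calls from the caller thread. A stripped-down sketch of that fan-out over a raw Eigen::ThreadPool (FanOut and the `remaining` counter are illustrative, not part of the commit; the real implementation carries a ref-counted ScheduleState to keep everything alive):

#define EIGEN_USE_THREADS
#include <atomic>
#include <cstdint>
#include <functional>

#include "Eigen/ThreadPool"

// Illustrative fan-out: each worker tries to schedule up to two more workers
// before doing its own share of work, forming a binary tree of Schedule()
// calls. The caller must keep `remaining` and `work` alive until all workers
// finish (the real code uses slinky::ref_count for this instead of raw
// references).
inline void FanOut(Eigen::ThreadPool& pool, std::atomic<int64_t>& remaining,
                   const std::function<void()>& work) {
  for (int i = 0; i < 2; ++i) {
    if (remaining.fetch_sub(1, std::memory_order_relaxed) <= 0) break;
    pool.Schedule([&pool, &remaining, &work] { FanOut(pool, remaining, work); });
  }
  work();
}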


@ -0,0 +1,61 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_SLINKY_THREADPOOL_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_SLINKY_THREADPOOL_H_
#include <cstddef>
#include <cstdint>
#include "slinky/base/function_ref.h"
#include "slinky/base/ref_count.h"
#include "slinky/base/thread_pool.h"
namespace Eigen {
struct ThreadPoolDevice;
class ThreadPoolInterface;
} // namespace Eigen
namespace xla::cpu {
// This is an implementation of slinky::thread_pool that uses absl::Mutex for
// synchronization and dispatches work to an Eigen::ThreadPoolInterface.
class SlinkyThreadPool final : public slinky::thread_pool {
public:
explicit SlinkyThreadPool(Eigen::ThreadPoolDevice* device);
explicit SlinkyThreadPool(Eigen::ThreadPoolInterface* threadpool);
~SlinkyThreadPool() final;
SlinkyThreadPool(SlinkyThreadPool&&) = default;
SlinkyThreadPool& operator=(SlinkyThreadPool&&) = default;
slinky::ref_count<task> enqueue(size_t n, task_body t,
int32_t max_workers) final;
void wait_for(task* t) final;
void wait_for(predicate_ref condition) final;
void atomic_call(slinky::function_ref<void()> t) final;
int thread_count() const final;
private:
class Impl;
slinky::ref_count<Impl> impl_;
};
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_SLINKY_THREADPOOL_H_
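A hypothetical caller-side sketch of this interface, mirroring the tests below. It assumes slinky::thread_pool::task_body is constructible from a lambda taking the work-item index (as the parallel_for calls in the tests suggest); parallel_for itself comes from the slinky::thread_pool base class, since it is not declared in the header above.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"

void Example() {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "example",
                                  /*num_threads=*/4);
  xla::cpu::SlinkyThreadPool pool(threads.AsEigenThreadPool());

  std::vector<int32_t> data(1024, 0);

  // High-level form used by the tests: a blocking parallel loop.
  pool.parallel_for(data.size(), [&](size_t i) { data[i]++; });

  // Lower-level form: enqueue a task and wait for it. wait_for() has the
  // calling thread process work items itself instead of just blocking.
  auto task = pool.enqueue(data.size(), [&](size_t i) { data[i]++; },
                           /*max_workers=*/static_cast<int32_t>(data.size()));
  pool.wait_for(&*task);
}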


@ -13,95 +13,81 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <vector>
#include <gtest/gtest.h>
#include "slinky/base/thread_pool.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"
namespace Eigen {
class ThreadPoolInterface;
} // namespace Eigen
namespace xla::cpu {
TEST(YnnThreadpoolImpl, inline_scheduling) {
auto ynn_threadpool =
CreateYnnThreadpool(static_cast<Eigen::ThreadPoolInterface*>(nullptr));
auto thread_pool =
reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
TEST(SlinkyThreadPoolTest, InlineScheduling) {
SlinkyThreadPool thread_pool(
static_cast<Eigen::ThreadPoolInterface*>(nullptr));
static constexpr size_t size = 10000;
std::vector<int32_t> data(size, 0);
auto inc = [&](size_t i) { data[i]++; };
thread_pool->parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
std::vector<int32_t> expected(size, 1);
EXPECT_EQ(data, expected);
}
TEST(YnnThreadpoolImpl, single_loop) {
TEST(SlinkyThreadPoolTest, SingleLoop) {
tsl::thread::ThreadPool test_thread_pool(tsl::Env::Default(), "test", 4);
auto ynn_threadpool =
CreateYnnThreadpool(test_thread_pool.AsEigenThreadPool());
auto thread_pool =
reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
SlinkyThreadPool thread_pool(test_thread_pool.AsEigenThreadPool());
static constexpr size_t size = 10000;
std::vector<int32_t> data(size, 0);
auto inc = [&](size_t i) { data[i]++; };
thread_pool->parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
std::vector<int32_t> expected(size, 1);
EXPECT_EQ(data, expected);
}
TEST(YnnThreadpoolImpl, loop_chain) {
TEST(SlinkyThreadPoolTest, LoopChain) {
tsl::thread::ThreadPool test_thread_pool(tsl::Env::Default(), "test", 4);
auto ynn_threadpool =
CreateYnnThreadpool(test_thread_pool.AsEigenThreadPool());
auto thread_pool =
reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
SlinkyThreadPool thread_pool(test_thread_pool.AsEigenThreadPool());
static constexpr size_t size = 10000;
std::vector<int32_t> data(size, 0);
auto inc = [&](size_t i) { data[i]++; };
thread_pool->parallel_for(size, inc);
thread_pool->parallel_for(size, inc);
thread_pool->parallel_for(size, inc);
thread_pool->parallel_for(size, inc);
thread_pool->parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
thread_pool.parallel_for(size, inc);
std::vector<int32_t> expected(size, 5);
EXPECT_EQ(data, expected);
}
TEST(YnnThreadpoolImpl, nested_loops) {
TEST(SlinkyThreadPoolTest, NestedLoops) {
tsl::thread::ThreadPool test_thread_pool(tsl::Env::Default(), "test", 4);
auto ynn_threadpool =
CreateYnnThreadpool(test_thread_pool.AsEigenThreadPool());
auto thread_pool =
reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
SlinkyThreadPool thread_pool(test_thread_pool.AsEigenThreadPool());
static constexpr size_t size = 100;
std::array<std::atomic<int32_t>, size> data = {{0}};
auto inc = [&](size_t i) { data[i]++; };
thread_pool->parallel_for(
size, [&](size_t i) { thread_pool->parallel_for(size, inc); });
thread_pool.parallel_for(
size, [&](size_t i) { thread_pool.parallel_for(size, inc); });
for (size_t i = 0; i < size; ++i) {
EXPECT_EQ(data[i], size);


@ -13,29 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <atomic>
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <optional>
#include <utility>
#include <vector>
#include "ynnpack/include/ynnpack.h"
#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/fixed_array.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/synchronization/mutex.h"
#include "slinky/base/function_ref.h"
#include "slinky/base/ref_count.h"
#include "slinky/base/thread_pool.h"
#include "xla/backends/cpu/runtime/ynnpack/slinky_threadpool.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#define EIGEN_USE_THREADS
#include "Eigen/ThreadPool"
@ -43,550 +29,11 @@ limitations under the License.
namespace xla::cpu {
namespace {
// This is an implementation of slinky::thread_pool that uses absl::Mutex for
// synchronization and dispatches work to an Eigen::ThreadPoolInterface.
class YnnThreadpoolImpl final : public slinky::thread_pool {
public:
explicit YnnThreadpoolImpl(Eigen::ThreadPoolDevice* device);
explicit YnnThreadpoolImpl(Eigen::ThreadPoolInterface* threadpool);
~YnnThreadpoolImpl() final;
YnnThreadpoolImpl(YnnThreadpoolImpl&&) = delete;
YnnThreadpoolImpl& operator=(YnnThreadpoolImpl&&) = delete;
slinky::ref_count<task> enqueue(size_t n, task_body t,
int32_t max_workers) final;
void wait_for(task* t) final;
void wait_for(predicate_ref condition) final;
void atomic_call(slinky::function_ref<void()> t) final;
int thread_count() const final;
private:
class impl;
slinky::ref_count<impl> impl_;
};
//===----------------------------------------------------------------------===//
// work_queue
//===----------------------------------------------------------------------===//
// Forward declare.
class worker;
// A work queue that partitions `num_work_items` work items into
// `num_partitions` partitions processed by parallel workers.
class work_queue {
public:
work_queue(size_t num_work_items, size_t num_partitions);
// Returns the next work item in the given partition. Returns std::nullopt
// if the partition is complete.
std::optional<size_t> pop_work_item(size_t partition_index);
// Return the partition [begin, end) work item range.
std::pair<size_t, size_t> partition_range(size_t partition_index) const;
size_t num_partitions() const { return partitions_.size(); }
// If work queue is empty, it means that all work items are being processed by
// the workers, and the task will be done once all workers complete.
bool is_empty() const { return empty_.load(std::memory_order_relaxed); }
private:
friend class worker;
// Work items partition tracking the next work item to process.
struct partition {
void initialize(size_t begin, size_t end);
// Tracks index of the next work item in the assigned partition.
ABSL_CACHELINE_ALIGNED std::atomic<size_t> index;
size_t begin;
size_t end;
};
void set_empty() { empty_.store(true, std::memory_order_relaxed); }
absl::FixedArray<partition, 32> partitions_;
ABSL_CACHELINE_ALIGNED std::atomic<bool> empty_;
};
} // namespace
void work_queue::partition::initialize(size_t begin, size_t end) {
index.store(begin, std::memory_order_relaxed);
this->begin = begin;
this->end = end;
}
work_queue::work_queue(size_t num_work_items, size_t num_partitions)
: partitions_(num_partitions), empty_(num_work_items == 0) {
size_t partition_size = num_work_items / num_partitions;
size_t rem_work = num_work_items % num_partitions;
for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
end = begin + partition_size + ((i < rem_work) ? 1 : 0);
partitions_[i].initialize(begin, end);
}
}
std::optional<size_t> work_queue::pop_work_item(size_t partition_index) {
DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
partition& partition = partitions_.data()[partition_index];
// Check if partition is already empty.
if (size_t index = partition.index.load(std::memory_order_relaxed);
ABSL_PREDICT_FALSE(index >= partition.end)) {
return std::nullopt;
}
// Try to acquire the next work item in the partition.
size_t index = partition.index.fetch_add(1, std::memory_order_relaxed);
return ABSL_PREDICT_FALSE(index >= partition.end) ? std::nullopt
: std::make_optional(index);
}
std::pair<size_t, size_t> work_queue::partition_range(
size_t partition_index) const {
DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
return {partitions_[partition_index].begin, partitions_[partition_index].end};
}
//===----------------------------------------------------------------------===//
// worker
//===----------------------------------------------------------------------===//
namespace {
// Worker processes work items from the work queue starting from the assigned
// work partition. Once the assigned partition is complete it tries to pop
// the work item from the next partition. Once the work queue is empty (the
// worker wraps around to the initial partition) it returns an empty work item.
class worker {
public:
worker(size_t partition_index, work_queue* queue);
std::optional<size_t> pop_work_item();
private:
size_t initial_partition_index_;
size_t partition_index_;
work_queue* queue_;
};
} // namespace
worker::worker(size_t partition_index, work_queue* queue)
: initial_partition_index_(partition_index),
partition_index_(partition_index),
queue_(queue) {}
std::optional<size_t> worker::pop_work_item() {
std::optional<size_t> work = queue_->pop_work_item(partition_index_);
if (ABSL_PREDICT_TRUE(work)) {
return work;
}
while (!work.has_value() && !queue_->is_empty()) {
// Wrap around to the first partition.
if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
partition_index_ = 0;
}
// We checked all partitions and got back to the partition we started from.
if (ABSL_PREDICT_FALSE(partition_index_ == initial_partition_index_)) {
queue_->set_empty();
break;
}
work = queue_->pop_work_item(partition_index_);
}
return work;
}
//===----------------------------------------------------------------------===//
// task_impl
//===----------------------------------------------------------------------===//
namespace {
// Running a task can result in three states:
//
// kPending: The task is still being processed by the worker threads.
// kComplete: The caller thread is the one who completed the task.
// kDone: The task is done and all work items have been processed; however,
// the caller thread didn't process any work items.
//
// We need this state to signal the waiter thread just once, from a thread that
// completed the task.
enum class task_state { kPending, kComplete, kDone };
class task_impl final : public YnnThreadpoolImpl::task {
public:
task_impl(YnnThreadpoolImpl::task_body body, size_t num_work_items,
size_t num_partitions);
// Runs this task by processing work items in the current thread.
task_state run();
int64_t num_workers() const;
bool is_empty_work_queue() const;
bool done() const final;
private:
YnnThreadpoolImpl::task_body body_;
work_queue work_queue_;
ABSL_CACHELINE_ALIGNED std::atomic<size_t> worker_index_;
ABSL_CACHELINE_ALIGNED std::atomic<size_t> pending_work_items_;
};
task_impl::task_impl(YnnThreadpoolImpl::task_body body, size_t num_work_items,
size_t num_partitions)
: body_(std::move(body)),
work_queue_(num_work_items, num_partitions),
worker_index_(0),
pending_work_items_(num_work_items) {}
task_state task_impl::run() {
// If we have more workers joining the task than the number of partitions,
// then we have to wrap around to the first partition.
size_t worker_index = worker_index_.fetch_add(1, std::memory_order_relaxed);
if (ABSL_PREDICT_FALSE(worker_index >= work_queue_.num_partitions())) {
worker_index %= work_queue_.num_partitions();
}
// Each worker processes the body using its own copy of the task.
worker w(worker_index, &work_queue_);
size_t num_processed_work_items = 0;
if (std::optional<size_t> item = w.pop_work_item()) {
YnnThreadpoolImpl::task_body body = body_;
do {
body(*item);
++num_processed_work_items;
} while ((item = w.pop_work_item()).has_value());
}
// The number of pending work items should never go below zero.
size_t previous_work_items = pending_work_items_.fetch_sub(
num_processed_work_items, std::memory_order_acq_rel);
DCHECK_GE(previous_work_items, num_processed_work_items);
// Task is done if we have no more work items to process. Task is complete if
// we are the one who processed the last work item.
bool is_done = previous_work_items == num_processed_work_items;
bool is_complete = is_done && num_processed_work_items > 0;
return is_complete ? task_state::kComplete
: is_done ? task_state::kDone
: task_state::kPending;
}
int64_t task_impl::num_workers() const {
return worker_index_.load(std::memory_order_relaxed);
}
bool task_impl::is_empty_work_queue() const { return work_queue_.is_empty(); }
bool task_impl::done() const {
return pending_work_items_.load(std::memory_order_acquire) == 0;
}
//===----------------------------------------------------------------------===//
// YnnThreadpoolImpl::impl
//===----------------------------------------------------------------------===//
// We keep a stack of tasks that are currently being processed by current
// thread, to avoid recursive calls.
static thread_local std::vector<const task_impl*> task_stack; // NOLINT
class YnnThreadpoolImpl::impl : public slinky::ref_counted<impl> {
public:
explicit impl(Eigen::ThreadPoolInterface* threadpool);
// Work on the single task and return the state of the task.
task_state work_on_task(task_impl* task);
// Work on all tasks in the queue. Returns when we run out of tasks to process.
void work_on_tasks(const absl::Condition& condition);
// Enqueues a new task into the queue and returns a reference to it.
slinky::ref_count<task_impl> enqueue(YnnThreadpoolImpl::task_body body,
size_t num_work_items,
size_t num_partitions);
void await(const absl::Condition& condition);
void atomic_call(slinky::function_ref<void()> t);
// Returns true if we can schedule more workers into the underlying scheduler.
bool can_schedule_workers() const;
// Schedules the given number of workers for the given task. Worker scheduling
// uses recursive work splitting and early exit if the task does not need any
// more workers, or if we have reached the maximum number of scheduled workers.
void schedule_workers(int64_t num_workers, slinky::ref_count<task_impl> task);
size_t thread_count() const { return thread_count_; }
private:
friend class slinky::ref_counted<impl>;
static void destroy(impl* ptr) { delete ptr; }
// A state of the work scheduling for a given task.
struct schedule_state : public slinky::ref_counted<schedule_state> {
schedule_state(int64_t remaining_workers, slinky::ref_count<task_impl> task,
slinky::ref_count<impl> impl)
: remaining_workers(remaining_workers),
task(std::move(task)),
impl(std::move(impl)) {}
static void destroy(schedule_state* ptr) { delete ptr; }
std::atomic<int64_t> remaining_workers;
slinky::ref_count<task_impl> task;
slinky::ref_count<impl> impl;
};
// Worker scheduling function for the underlying scheduler.
template <bool release_impl_ref>
static void schedule_workers(schedule_state* context);
// Dequeues a pending task from the queue.
slinky::ref_count<task_impl> dequeue();
// Signals all waiter threads waiting on the waiter mutex.
void signal_waiters();
Eigen::ThreadPoolInterface* threadpool_;
size_t thread_count_;
std::deque<slinky::ref_count<task_impl>> tasks_ ABSL_GUARDED_BY(tasks_mutex_);
// A mutex for guarding mutable state accessed concurrently.
ABSL_CACHELINE_ALIGNED absl::Mutex tasks_mutex_;
// A mutex for signalling threads waiting on the tasks or conditions.
ABSL_CACHELINE_ALIGNED absl::Mutex waiter_mutex_;
};
YnnThreadpoolImpl::impl::impl(Eigen::ThreadPoolInterface* threadpool)
: threadpool_(threadpool),
thread_count_(threadpool_ ? threadpool_->NumThreads() : 0) {}
slinky::ref_count<task_impl> YnnThreadpoolImpl::impl::enqueue(
YnnThreadpoolImpl::task_body body, size_t num_work_items,
size_t num_partitions) {
slinky::ref_count<task_impl> task(
new task_impl(std::move(body), num_work_items, num_partitions));
absl::MutexLock lock(tasks_mutex_);
return tasks_.emplace_back(std::move(task));
}
slinky::ref_count<task_impl> YnnThreadpoolImpl::impl::dequeue() {
absl::MutexLock lock(tasks_mutex_);
for (auto i = tasks_.begin(); i != tasks_.end();) {
slinky::ref_count<task_impl>& task = *i;
// Task doesn't have any more work items to process.
if (ABSL_PREDICT_FALSE(task->is_empty_work_queue())) {
i = tasks_.erase(i);
continue;
}
// Don't run the same task multiple times on the same thread.
if (ABSL_PREDICT_FALSE(absl::c_contains(task_stack, &*task))) {
++i;
continue;
}
return task;
}
return nullptr;
}
task_state YnnThreadpoolImpl::impl::work_on_task(task_impl* task) {
DCHECK(absl::c_find(task_stack, task) == task_stack.end());
task_stack.push_back(task);
task_state state = task->run();
task_stack.pop_back();
// If we are the one who completed the task, we signal the waiters to wake up
// any threads that are waiting for the task completion. If the task was
// completed by another worker, we do nothing to avoid the cost of waking up
// the same thread multiple times.
if (ABSL_PREDICT_FALSE(state == task_state::kComplete)) {
signal_waiters();
}
return state;
}
void YnnThreadpoolImpl::impl::work_on_tasks(const absl::Condition& condition) {
while (slinky::ref_count<task_impl> task = dequeue()) {
work_on_task(&*task);
if (ABSL_PREDICT_TRUE(condition.Eval())) {
return;
}
}
}
void YnnThreadpoolImpl::impl::await(const absl::Condition& condition) {
if (ABSL_PREDICT_FALSE(!condition.Eval())) {
absl::MutexLock lock(waiter_mutex_);
waiter_mutex_.Await(condition);
}
}
void YnnThreadpoolImpl::impl::signal_waiters() {
absl::MutexLock lock(waiter_mutex_);
}
void YnnThreadpoolImpl::impl::atomic_call(slinky::function_ref<void()> t) {
absl::MutexLock lock(waiter_mutex_);
t();
}
bool YnnThreadpoolImpl::impl::can_schedule_workers() const {
// One reference is owned by the parent YnnThreadpoolImpl, every other
// reference is owned by a worker scheduled into the underlying scheduler.
return ref_count() < 1 + thread_count();
}
void YnnThreadpoolImpl::impl::schedule_workers(
int64_t num_workers, slinky::ref_count<task_impl> task) {
if (ABSL_PREDICT_TRUE(num_workers > 0 && can_schedule_workers())) {
slinky::ref_count<schedule_state> state(
new schedule_state(num_workers - 1, std::move(task), {this}));
threadpool_->Schedule([state = state.take()]() {
schedule_workers</*release_impl_ref=*/false>(state);
});
}
}
template <bool release_impl_ref>
void YnnThreadpoolImpl::impl::schedule_workers(schedule_state* context) {
auto state = slinky::ref_count<schedule_state>::assume(context);
// We recursively keep scheduling workers into the underlying scheduler.
// This is more efficient than scheduling them sequentially from a single
// thread, because workers can start processing the task sooner and we
// distribute thread wake-ups evenly across underlying threads.
static constexpr int32_t kNumRecursiveWorkers = 2;
for (size_t i = 0; i < kNumRecursiveWorkers; ++i) {
bool schedule_worker =
state->impl->can_schedule_workers() &&
!state->task->is_empty_work_queue() &&
state->remaining_workers.fetch_sub(1, std::memory_order_relaxed) > 0;
if (ABSL_PREDICT_TRUE(!schedule_worker)) {
break;
}
// Add +1 reference to account for the scheduled worker, as we use `impl`
// reference count to track the number of active workers.
state->impl->add_ref();
state->impl->threadpool_->Schedule(
[state = slinky::ref_count<schedule_state>(state).take()]() {
YnnThreadpoolImpl::impl::schedule_workers</*release_impl_ref=*/true>(
state);
});
}
// Keep processing tasks from the queue until we are out of tasks.
static constexpr bool kFalse = false;
state->impl->work_on_tasks(absl::Condition(&kFalse));
// One `impl` reference implicitly owned by the `state`, every additional
// reference is added and released explicitly by the worker task.
if constexpr (release_impl_ref) {
state->impl->release();
}
}
//===----------------------------------------------------------------------===//
// YnnThreadpoolImpl
//===----------------------------------------------------------------------===//
YnnThreadpoolImpl::YnnThreadpoolImpl(Eigen::ThreadPoolDevice* device)
: impl_(new impl(device ? device->getPool() : nullptr)) {}
YnnThreadpoolImpl::YnnThreadpoolImpl(Eigen::ThreadPoolInterface* threadpool)
: impl_(new impl(threadpool)) {}
YnnThreadpoolImpl::~YnnThreadpoolImpl() = default;
slinky::ref_count<YnnThreadpoolImpl::task> YnnThreadpoolImpl::enqueue(
size_t n, task_body t, int32_t max_workers) {
CHECK_GE(max_workers, n);
// Don't create more partitions than the number of threads. Also make sure
// that we have at least one partition (if we don't have a scheduler).
size_t num_partitions = std::min<size_t>(n, thread_count());
num_partitions = std::max<size_t>(1, num_partitions);
auto task = impl_->enqueue(std::move(t), n, num_partitions);
// If we don't have any worker threads, we return a task to the caller, and
// assume that the caller will wait on it.
if (ABSL_PREDICT_FALSE(impl_->thread_count() == 0)) {
return task;
}
// We assume that the caller will immediately start working on the task, so we
// need to schedule workers only for the remaining number of partitions.
impl_->schedule_workers(/*num_workers=*/num_partitions - 1, task);
return task;
}
void YnnThreadpoolImpl::wait_for(task* t) {
task_impl* task = static_cast<task_impl*>(t);
task_state state = impl_->work_on_task(task);
// If the task is complete or done, we are immediately done with waiting.
if (ABSL_PREDICT_TRUE(state == task_state::kComplete ||
state == task_state::kDone)) {
return;
}
// Switch to the work stealing mode and work on other tasks in the queue
// until the given task is done.
impl_->work_on_tasks(absl::Condition(task, &task_impl::done));
impl_->await(absl::Condition(task, &task_impl::done));
}
void YnnThreadpoolImpl::wait_for(predicate_ref condition) {
impl_->work_on_tasks(absl::Condition(&condition));
impl_->await(absl::Condition(&condition));
}
void YnnThreadpoolImpl::atomic_call(slinky::function_ref<void()> t) {
impl_->atomic_call(t);
}
int YnnThreadpoolImpl::thread_count() const { return impl_->thread_count(); }
} // namespace
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
Eigen::ThreadPoolInterface* threadpool) {
return CreateYnnThreadpool([&](ynn_threadpool_t* ynn_threadpool) {
*ynn_threadpool =
reinterpret_cast<ynn_threadpool_t>(new YnnThreadpoolImpl(threadpool));
reinterpret_cast<ynn_threadpool_t>(new SlinkyThreadPool(threadpool));
return ynn_status_success;
});
}
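For callers going through the YNN interop layer, the handle returned by CreateYnnThreadpool now wraps a SlinkyThreadPool rather than the removed YnnThreadpoolImpl; the deleted test above shows how the opaque handle maps back to a slinky::thread_pool. A hedged sketch of that round trip, with names mirroring the deleted ynn_threadpool_test:

#include <cstddef>

#include "slinky/base/thread_pool.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"

// Mirrors the deleted ynn_threadpool_test above: the opaque ynn_threadpool_t
// handle is reinterpreted back to the slinky::thread_pool it wraps, which is
// now a SlinkyThreadPool.
void YnnInteropExample() {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 4);
  auto ynn_threadpool =
      xla::cpu::CreateYnnThreadpool(threads.AsEigenThreadPool());
  auto* thread_pool =
      reinterpret_cast<slinky::thread_pool*>(ynn_threadpool->get());
  thread_pool->parallel_for(100, [](size_t i) { (void)i; /* per-item work */ });
}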