[xla:cpu] Use work_item instead of a task in WorkQueue/Worker API

To avoid confusion between the different kinds of tasks we have in Worker/WorkQueue and the SlinklyThreadPool in XLA, use the more generic "work item" name.

PiperOrigin-RevId: 823191886
Author: Eugene Zhulenev
Date: 2025-10-23 14:35:38 -07:00
Committed by: TensorFlower Gardener
Parent: 512b85961f
Commit: d55e5c1d9f
4 changed files with 128 additions and 129 deletions
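Seen from the call site, this change is purely a rename; the call shape stays the same. Below is a minimal caller-side sketch of the renamed API, based on the signatures and test usage in the diffs that follow. The header paths and thread pool setup here are assumptions for illustration, not part of this commit:

#include <cstddef>
#include <vector>

#include "absl/status/status.h"
#include "xla/backends/cpu/runtime/work_queue.h"  // assumed header path
#include "xla/tsl/concurrency/async_value_ref.h"  // assumed header path
#include "xla/tsl/platform/env.h"                 // assumed header path
#include "xla/tsl/platform/threadpool.h"          // assumed header path

void DoubleAll(std::vector<float>& values) {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "example",
                                  /*num_threads=*/8);

  // Each invocation handles exactly one work item (formerly a "task").
  auto event = xla::cpu::Worker::Parallelize(
      threads.AsEigenThreadPool(), /*num_workers=*/8,
      /*num_work_items=*/values.size(),
      [&](size_t work_item) -> absl::Status {
        values[work_item] *= 2.0f;
        return absl::OkStatus();
      });

  // Resolves once every work item has been processed.
  tsl::BlockUntilReady(event);
}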


@@ -62,10 +62,10 @@ static absl::InlinedVector<XLA_CPU_KernelArg, 8> ConvertBuffersToKernelArgs(
 }
 
 template <bool num_workgroups_x_only>
-class Kernel::ParallelTask {
+class Kernel::Task {
  public:
-  ParallelTask(XLA_CPU_Kernel* kernel, NumWorkGroups num_workgroups,
-               absl::Span<const XLA_CPU_KernelArg> args);
+  Task(XLA_CPU_Kernel* kernel, NumWorkGroups num_workgroups,
+       absl::Span<const XLA_CPU_KernelArg> args);
 
   // Invokes a host kernel for a given task index.
   absl::Status operator()(size_t task_index) const;
@@ -87,7 +87,7 @@ class Kernel::ParallelTask {
 };
 
 template <bool num_workgroups_x_only>
-Kernel::ParallelTask<num_workgroups_x_only>::ParallelTask(
+Kernel::Task<num_workgroups_x_only>::Task(
     XLA_CPU_Kernel* kernel, NumWorkGroups num_workgroups,
     absl::Span<const XLA_CPU_KernelArg> args)
     : kernel_(kernel),
@@ -98,7 +98,7 @@ Kernel::ParallelTask<num_workgroups_x_only>::ParallelTask(
       stride_y_(num_workgroups.x) {}
 
 template <bool num_workgroups_x_only>
-absl::Status Kernel::ParallelTask<num_workgroups_x_only>::operator()(
+absl::Status Kernel::Task<num_workgroups_x_only>::operator()(
     size_t task_index) const {
   DCHECK_LT(task_index, num_tasks_) << "Task index out of range";  // Crash OK
@@ -117,7 +117,7 @@ absl::Status Kernel::ParallelTask<num_workgroups_x_only>::operator()(
 }
 
 template <bool num_workgroups_x_only>
-XLA_CPU_WorkGroupId Kernel::ParallelTask<num_workgroups_x_only>::Delinearize(
+XLA_CPU_WorkGroupId Kernel::Task<num_workgroups_x_only>::Delinearize(
    uint64_t task_index) const {
   // In the most common case we parallelize only over the `x` dimension.
   if constexpr (num_workgroups_x_only) {
@@ -197,14 +197,12 @@ tsl::AsyncValueRef<LaunchEvent> Kernel::Launch(
                              std::numeric_limits<uint16_t>::max());
 
   if (ABSL_PREDICT_TRUE(num_workgroups.y == 1 && num_workgroups.z == 1)) {
-    return Worker::Parallelize(
-        device->getPool(), num_workers, num_tasks,
-        ParallelTask<true>(kernel_, num_workgroups, args));
+    return Worker::Parallelize(device->getPool(), num_workers, num_tasks,
+                               Task<true>(kernel_, num_workgroups, args));
   }
 
-  return Worker::Parallelize(
-      device->getPool(), num_workers, num_tasks,
-      ParallelTask<false>(kernel_, num_workgroups, args));
+  return Worker::Parallelize(device->getPool(), num_workers, num_tasks,
+                             Task<false>(kernel_, num_workgroups, args));
 }
 
 }  // namespace xla::cpu
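In the hunks above, Kernel::Task::Delinearize maps a linear work-item index onto (x, y, z) work-group coordinates, with num_workgroups.x as the y stride. A standalone sketch of that row-major arithmetic; the struct here is a stand-in for illustration, not the XLA_CPU_WorkGroupId type itself:

#include <cstdint>

struct WorkGroupId { uint64_t x, y, z; };

WorkGroupId Delinearize(uint64_t index, uint64_t nx, uint64_t ny) {
  const uint64_t stride_y = nx;       // work items per row of work groups
  const uint64_t stride_z = nx * ny;  // work items per xy-plane of work groups
  WorkGroupId id;
  id.z = index / stride_z;
  index %= stride_z;
  id.y = index / stride_y;
  id.x = index % stride_y;
  return id;
}

// With num_workgroups = {4, 2, 2}, index 11 maps to {x=3, y=0, z=1}.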


@@ -105,9 +105,9 @@ class Kernel {
   }
 
  private:
-  // A kernel parallel task that is used to parallelize host kernel execution.
+  // A kernel task that is used to parallelize host kernel execution.
   template <bool num_workgroups_x_only>
-  class ParallelTask;
+  class Task;
 
   std::unique_ptr<KernelFunction> function_;
   XLA_CPU_Kernel* kernel_;  // pointer to the kernel owned by `function_`


@@ -40,17 +40,17 @@ limitations under the License.
 namespace xla::cpu {
 
-// A work queue that partitions `num_tasks` tasks into `num_partitions`
-// partitions processed by parallel workers.
+// A work queue that partitions `num_work_items` work items into
+// `num_partitions` partitions processed by parallel workers.
 class WorkQueue {
  public:
-  WorkQueue(size_t num_tasks, size_t num_partitions);
+  WorkQueue(size_t num_work_items, size_t num_partitions);
 
-  // Returns the next task in the given partition. Returns std::nullopt
+  // Returns the next work item in the given partition. Returns std::nullopt
   // if the partition is complete.
   std::optional<size_t> Pop(size_t partition_index);
 
-  // Return the partition [begin, end) task range.
+  // Return the partition [begin, end) work item range.
   std::pair<size_t, size_t> partition_range(size_t partition_index) const;
 
   size_t num_partitions() const { return partitions_.size(); }
@@ -63,7 +63,7 @@ class WorkQueue {
   struct Partition {
     void Initialize(size_t begin, size_t end);
 
-    // Tracks index of the next task in the assigned partition.
+    // Tracks index of the next work item in the assigned partition.
     ABSL_CACHELINE_ALIGNED std::atomic<size_t> index;
     size_t begin;
    size_t end;
@@ -85,10 +85,10 @@ class WorkQueue {
   ABSL_CACHELINE_ALIGNED std::atomic<size_t> num_work_stealing_workers_;
 };
 
-// Worker processes tasks from the work queue starting from the assigned
+// Worker processes work items from the work queue starting from the assigned
 // work partition. Once the assigned partition is complete it tries to pop
-// the task from the next partition. Once the work queue is empty (the worker
-// wraps around to the initial partition) it returns an empty task.
+// the work item from the next partition. Once the work queue is empty (the
+// worker wraps around to the initial partition) it returns an empty work item.
 class Worker {
  public:
   Worker(size_t worker_index, WorkQueue* queue);
@@ -96,23 +96,23 @@ class Worker {
   std::optional<size_t> Pop();
 
   // Schedule `num_workers` workers into the Eigen thread pool that process
-  // `num_tasks` parallel tasks and return an async value that becomes
+  // `num_work_items` parallel work items and return an async value that becomes
   // available when all workers are completed.
-  template <typename ParallelTask>
+  template <typename ParallelWork>
   static tsl::AsyncValueRef<tsl::Chain> Parallelize(
       Eigen::ThreadPoolInterface* thread_pool, size_t num_workers,
-      size_t num_tasks, ParallelTask&& parallel_task);
+      size_t num_work_items, ParallelWork&& parallel_work);
 
  private:
-  template <typename ParallelTask>
+  template <typename ParallelWork>
   struct ParallelizeContext;
 
-  template <typename ParallelTask>
-  static absl::Status ExecuteInline(size_t num_tasks,
-                                    ParallelTask&& parallel_task);
+  template <typename ParallelWork>
+  static absl::Status ExecuteInline(size_t num_work_items,
+                                    ParallelWork&& parallel_work);
 
-  template <typename ParallelTask>
-  static void Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
+  template <typename ParallelWork>
+  static void Parallelize(std::shared_ptr<ParallelizeContext<ParallelWork>> ctx,
                           uint16_t start_index, uint16_t end_index);
 
   size_t worker_index_;
@@ -126,14 +126,14 @@ inline void WorkQueue::Partition::Initialize(size_t begin, size_t end) {
   this->end = end;
 }
 
-inline WorkQueue::WorkQueue(size_t num_tasks, size_t num_partitions)
+inline WorkQueue::WorkQueue(size_t num_work_items, size_t num_partitions)
     : partitions_(num_partitions),
-      empty_(num_tasks == 0),
+      empty_(num_work_items == 0),
       num_work_stealing_workers_(0) {
-  size_t partition_size = num_tasks / num_partitions;
-  size_t rem_tasks = num_tasks % num_partitions;
+  size_t partition_size = num_work_items / num_partitions;
+  size_t rem_work_items = num_work_items % num_partitions;
   for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
-    end = begin + partition_size + ((i < rem_tasks) ? 1 : 0);
+    end = begin + partition_size + ((i < rem_work_items) ? 1 : 0);
     partitions_[i].Initialize(begin, end);
   }
 }
@@ -148,7 +148,7 @@ inline std::optional<size_t> WorkQueue::Pop(size_t partition_index) {
     return std::nullopt;
   }
 
-  // Try to acquire the next task in the partition.
+  // Try to acquire the next work item in the partition.
   size_t index = partition.index.fetch_add(1, std::memory_order_relaxed);
   return ABSL_PREDICT_FALSE(index >= partition.end) ? std::nullopt
                                                     : std::make_optional(index);
@@ -183,18 +183,18 @@ inline Worker::Worker(size_t worker_index, WorkQueue* queue)
       queue_(queue) {}
 
 inline std::optional<size_t> Worker::Pop() {
-  std::optional<size_t> task = queue_->Pop(partition_index_);
-  if (ABSL_PREDICT_TRUE(task)) {
-    return task;
+  std::optional<size_t> work_item = queue_->Pop(partition_index_);
+  if (ABSL_PREDICT_TRUE(work_item)) {
+    return work_item;
   }
 
-  // If we didn't find a task in the initially assigned partition, notify the
-  // work queue that we are switching to work stealing mode.
+  // If we didn't find a work item in the initially assigned partition, notify
+  // the work queue that we are switching to work stealing mode.
   if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
     queue_->NotifyWorkStealingWorker();
   }
 
-  while (!task.has_value() && !queue_->IsEmpty()) {
+  while (!work_item.has_value() && !queue_->IsEmpty()) {
     // Wrap around to the first partition.
     if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
       partition_index_ = 0;
@@ -206,44 +206,44 @@ inline std::optional<size_t> Worker::Pop() {
       break;
     }
 
-    task = queue_->Pop(partition_index_);
+    work_item = queue_->Pop(partition_index_);
   }
 
-  return task;
+  return work_item;
 }
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 struct Worker::ParallelizeContext {
   ParallelizeContext(Eigen::ThreadPoolInterface* thread_pool,
                      tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
-                     size_t num_tasks, ParallelTask&& parallel_task);
+                     size_t num_work_items, ParallelWork&& parallel_work);
 
   Eigen::ThreadPoolInterface* thread_pool;
   tsl::CountDownAsyncValueRef<tsl::Chain> count_down;
 
   WorkQueue work_queue;
-  ParallelTask parallel_task;
+  ParallelWork parallel_work;
 };
 
-template <typename ParallelTask>
-Worker::ParallelizeContext<ParallelTask>::ParallelizeContext(
+template <typename ParallelWork>
+Worker::ParallelizeContext<ParallelWork>::ParallelizeContext(
     Eigen::ThreadPoolInterface* thread_pool,
-    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_tasks,
-    ParallelTask&& parallel_task)
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_work_items,
+    ParallelWork&& parallel_work)
     : thread_pool(thread_pool),
      count_down(std::move(count_down)),
-      work_queue(num_tasks, /*num_partitions=*/this->count_down.count()),
-      parallel_task(std::forward<ParallelTask>(parallel_task)) {}
+      work_queue(num_work_items, /*num_partitions=*/this->count_down.count()),
+      parallel_work(std::forward<ParallelWork>(parallel_work)) {}
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
+void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelWork>> ctx,
                          uint16_t start_index, uint16_t end_index) {
   DCHECK_LT(start_index, end_index) << "Invalid worker index range";
 
-  using R = std::invoke_result_t<ParallelTask, size_t>;
+  using R = std::invoke_result_t<ParallelWork, size_t>;
   static_assert(std::is_same_v<R, absl::Status> || std::is_void_v<R>,
-                "Unsupported parallel task return type");
+                "Unsupported parallel work return type");
 
   // Recursively split assigned workers into two halves and schedule the
   // right half into the thread pool.
@@ -254,7 +254,7 @@ void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
   }
 
   // If we have workers in the work stealing mode, we can skip scheduling
-  // more tasks as existing workers will process remaining partitions. By
+  // more workers as existing workers will process remaining partitions. By
   // doing this optimization we avoid unnecessary thread pool overheads.
   size_t skip_workers =
       ctx->work_queue.DecrementWorkStealingWorkers(end_index - start_index);
@@ -283,54 +283,54 @@ void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
   // Execute the `start_index` worker in the caller thread.
   Worker worker(start_index, &ctx->work_queue);
-  size_t num_processed_tasks = 0;
+  size_t num_processed_work_items = 0;
 
   // Keep track of the first error status encountered by any of the workers.
   absl::Status status;
 
-  while (std::optional<size_t> task = worker.Pop()) {
+  while (std::optional<size_t> work_item = worker.Pop()) {
     if constexpr (std::is_same_v<R, absl::Status>) {
       if (ABSL_PREDICT_TRUE(status.ok())) {
-        status.Update(ctx->parallel_task(*task));
+        status.Update(ctx->parallel_work(*work_item));
       }
     } else {
-      ctx->parallel_task(*task);
+      ctx->parallel_work(*work_item);
     }
-    ++num_processed_tasks;
+    ++num_processed_work_items;
   }
 
-  ctx->count_down.CountDown(num_processed_tasks, std::move(status));
+  ctx->count_down.CountDown(num_processed_work_items, std::move(status));
 }
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 ABSL_ATTRIBUTE_ALWAYS_INLINE absl::Status Worker::ExecuteInline(
-    size_t num_tasks, ParallelTask&& parallel_task) {
-  using R = std::invoke_result_t<ParallelTask, size_t>;
+    size_t num_work_items, ParallelWork&& parallel_work) {
+  using R = std::invoke_result_t<ParallelWork, size_t>;
   static_assert(std::is_same_v<R, absl::Status> || std::is_void_v<R>,
-                "Unsupported parallel task return type");
+                "Unsupported parallel work return type");
 
-  for (size_t i = 0; i < num_tasks; ++i) {
+  for (size_t i = 0; i < num_work_items; ++i) {
     if constexpr (std::is_same_v<R, absl::Status>) {
-      absl::Status status = parallel_task(i);
+      absl::Status status = parallel_work(i);
       if (ABSL_PREDICT_FALSE(!status.ok())) {
         return status;
       }
     } else {
-      parallel_task(i);
+      parallel_work(i);
     }
   }
 
   return absl::OkStatus();
 }
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<tsl::Chain> Worker::Parallelize(
     Eigen::ThreadPoolInterface* thread_pool, size_t num_workers,
-    size_t num_tasks, ParallelTask&& parallel_task) {
+    size_t num_work_items, ParallelWork&& parallel_work) {
   // Short-circuit single-threaded execution.
   if (ABSL_PREDICT_FALSE(num_workers == 1)) {
-    if (absl::Status status =
-            ExecuteInline(num_tasks, std::forward<ParallelTask>(parallel_task));
+    if (absl::Status status = ExecuteInline(
+            num_work_items, std::forward<ParallelWork>(parallel_work));
         ABSL_PREDICT_FALSE(!status.ok())) {
       return status;
     }
@@ -341,16 +341,16 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<tsl::Chain> Worker::Parallelize(
   if (ABSL_PREDICT_FALSE(num_workers > std::numeric_limits<uint16_t>::max())) {
     num_workers = std::numeric_limits<uint16_t>::max();
   }
 
-  // Ensure we don't launch more workers than tasks.
-  // Extra workers would be idle or cause out-of-bounds partition access.
-  num_workers = std::min(num_tasks, num_workers);
+  // Ensure we don't launch more workers than work items. Extra workers would be
+  // idle or cause out-of-bounds partition access.
+  num_workers = std::min(num_work_items, num_workers);
 
-  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
+  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_work_items);
   auto execute_event = count_down.AsRef();
 
-  auto ctx = std::make_shared<ParallelizeContext<ParallelTask>>(
-      thread_pool, std::move(count_down), num_tasks,
-      std::forward<ParallelTask>(parallel_task));
+  auto ctx = std::make_shared<ParallelizeContext<ParallelWork>>(
+      thread_pool, std::move(count_down), num_work_items,
+      std::forward<ParallelWork>(parallel_work));
 
   Parallelize(std::move(ctx), 0, num_workers);
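The WorkQueue constructor above splits work items across partitions as evenly as possible, giving one extra item to each of the first num_work_items % num_partitions partitions. A standalone sketch of that arithmetic, which reproduces the ranges checked by the WorkQueuePartitions test below:

#include <cstddef>
#include <utility>
#include <vector>

std::vector<std::pair<size_t, size_t>> PartitionRanges(size_t num_work_items,
                                                       size_t num_partitions) {
  std::vector<std::pair<size_t, size_t>> ranges;
  size_t partition_size = num_work_items / num_partitions;
  size_t rem = num_work_items % num_partitions;
  for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
    // First `rem` partitions absorb the remainder, one extra item each.
    end = begin + partition_size + (i < rem ? 1 : 0);
    ranges.emplace_back(begin, end);  // [begin, end) work item range
  }
  return ranges;
}

// PartitionRanges(5, 4) == {{0, 2}, {2, 3}, {3, 4}, {4, 5}}, matching the
// expectations in the WorkQueuePartitions test below.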


@@ -41,7 +41,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   };
 
   {
-    WorkQueue queue(/*num_tasks=*/2, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/2, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 1));
     EXPECT_EQ(queue.partition_range(1), task_range(1, 2));
     EXPECT_EQ(queue.partition_range(2), task_range(2, 2));
@@ -49,7 +49,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/4, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/4, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 1));
     EXPECT_EQ(queue.partition_range(1), task_range(1, 2));
     EXPECT_EQ(queue.partition_range(2), task_range(2, 3));
@@ -57,7 +57,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/5, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/5, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 2));
     EXPECT_EQ(queue.partition_range(1), task_range(2, 3));
     EXPECT_EQ(queue.partition_range(2), task_range(3, 4));
@@ -65,7 +65,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/9, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/9, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 3));
     EXPECT_EQ(queue.partition_range(1), task_range(3, 5));
     EXPECT_EQ(queue.partition_range(2), task_range(5, 7));
@@ -73,7 +73,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/14, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/14, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 4));
     EXPECT_EQ(queue.partition_range(1), task_range(4, 8));
     EXPECT_EQ(queue.partition_range(2), task_range(8, 11));
@@ -107,17 +107,17 @@ TEST(WorkQueueTest, WorkQueue) {
     for (size_t num_partitions : {1, 2, 3, 4, 5, 6, 7, 8}) {
       WorkQueue queue(size, num_partitions);
 
-      std::vector<size_t> expected_tasks(size);
-      absl::c_iota(expected_tasks, 0);
+      std::vector<size_t> expected_work_items(size);
+      absl::c_iota(expected_work_items, 0);
 
-      std::vector<size_t> tasks;
+      std::vector<size_t> work_items;
       for (size_t i = 0; i < num_partitions; ++i) {
-        while (std::optional<size_t> task = queue.Pop(i)) {
-          tasks.push_back(*task);
+        while (std::optional<size_t> work_item = queue.Pop(i)) {
+          work_items.push_back(*work_item);
         }
       }
 
-      EXPECT_EQ(tasks, expected_tasks);
+      EXPECT_EQ(work_items, expected_work_items);
     }
   }
 }
@@ -126,21 +126,21 @@ TEST(WorkQueueTest, Worker) {
   for (size_t size : {1, 2, 4, 8, 16, 32, 64}) {
     for (size_t num_partitions : {1, 2, 3, 4, 5, 6, 7, 8}) {
       // We check that no matter what is the initial partition, the worker
-      // processes all tasks in the queue.
+      // processes all work items in the queue.
       for (size_t i = 0; i < num_partitions; ++i) {
         WorkQueue queue(size, num_partitions);
         Worker worker(i, &queue);
 
-        std::vector<size_t> expected_tasks(size);
-        absl::c_iota(expected_tasks, 0);
+        std::vector<size_t> expected_work_items(size);
+        absl::c_iota(expected_work_items, 0);
 
-        std::vector<size_t> tasks;
-        while (std::optional<size_t> task = worker.Pop()) {
-          tasks.push_back(*task);
+        std::vector<size_t> work_items;
+        while (std::optional<size_t> work_item = worker.Pop()) {
+          work_items.push_back(*work_item);
         }
 
-        absl::c_sort(tasks);  // we pop tasks out of order
-        EXPECT_EQ(tasks, expected_tasks);
+        absl::c_sort(work_items);  // we pop work_items out of order
+        EXPECT_EQ(work_items, expected_work_items);
       }
     }
   }
@@ -154,22 +154,22 @@ TEST(WorkQueueTest, WorkerConcurrency) {
   WorkQueue queue(size, num_partitions);
 
-  // Check that we pop exactly `size` tasks.
-  std::atomic<size_t> num_tasks(0);
+  // Check that we pop exactly `size` work_items.
+  std::atomic<size_t> num_work_items(0);
 
   absl::BlockingCounter counter(num_partitions);
   for (size_t i = 0; i < num_partitions; ++i) {
     threads.Schedule([&, i] {
       Worker worker(i, &queue);
-      while (std::optional<size_t> task = worker.Pop()) {
-        ++num_tasks;
+      while (std::optional<size_t> work_item = worker.Pop()) {
+        ++num_work_items;
       }
       counter.DecrementCount();
     });
   }
 
   counter.Wait();
-  EXPECT_EQ(num_tasks.load(), size);
+  EXPECT_EQ(num_work_items.load(), size);
 }
 
 TEST(WorkQueueTest, WorkerParallelize) {
@@ -215,35 +215,36 @@ TEST(WorkQueueTest, WorkerParallelizeVariousWorkerTaskRatios) {
   tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 16);
 
   struct TestCase {
-    size_t num_tasks;
+    size_t num_work_items;
     size_t num_workers;
   };
 
   std::vector<TestCase> test_cases = {
-      {0, 1},     // Edge: no tasks
+      {0, 1},     // Edge: no work_items
       {1, 1},     // Edge: single task, single worker
       {1, 8},     // Edge: single task, many workers
       {8, 1},     // Serial execution
-      {8, 4},     // Fewer workers than tasks
+      {8, 4},     // Fewer workers than work_items
       {8, 8},     // Equal
-      {8, 16},    // More workers than tasks
-      {1024, 8},  // Many tasks, fewer workers
-      {1024, 64}  // Many tasks, many workers
+      {8, 16},    // More workers than work_items
+      {1024, 8},  // Many work_items, fewer workers
+      {1024, 64}  // Many work_items, many workers
   };
 
   for (const auto& test : test_cases) {
-    std::vector<size_t> data(test.num_tasks, 0);
+    std::vector<size_t> data(test.num_work_items, 0);
 
     auto event = Worker::Parallelize(
-        threads.AsEigenThreadPool(), test.num_workers, test.num_tasks,
+        threads.AsEigenThreadPool(), test.num_workers, test.num_work_items,
         [&](size_t task_index) { ++data[task_index]; });
 
     tsl::BlockUntilReady(event);
 
-    // Verify that all tasks were executed once (if any exist)
-    std::vector<size_t> expected(test.num_tasks, 1);
-    EXPECT_EQ(data, expected) << "Failed for num_tasks=" << test.num_tasks
-                              << ", num_workers=" << test.num_workers;
+    // Verify that all work_items were executed once (if any exist)
+    std::vector<size_t> expected(test.num_work_items, 1);
+    EXPECT_EQ(data, expected)
+        << "Failed for num_work_items=" << test.num_work_items
+        << ", num_workers=" << test.num_workers;
   }
 }
@@ -251,35 +252,35 @@ TEST(WorkQueueTest, WorkerParallelizeVariousWorkerTaskRatios) {
 // Performance benchmarks.
 //===----------------------------------------------------------------------===//
 
-static void BM_PopTask(benchmark::State& state) {
+static void BM_PopWorkItem(benchmark::State& state) {
   std::optional<WorkQueue> queue;
   std::optional<Worker> worker;
 
   size_t n = 0;
   for (auto _ : state) {
     if (n++ % (1024 * 10) == 0) {
-      queue.emplace(/*num_tasks=*/1024 * 10, /*num_partitions=*/10);
+      queue.emplace(/*num_work_items=*/1024 * 10, /*num_partitions=*/10);
       worker.emplace(0, &*queue);
     }
     worker->Pop();
   }
 }
 
-BENCHMARK(BM_PopTask);
+BENCHMARK(BM_PopWorkItem);
 
-static void BM_PopTaskMultiThreaded(benchmark::State& state) {
+static void BM_PopWorkItemMultiThreaded(benchmark::State& state) {
   size_t num_threads = state.range(0);
-  tsl::thread::ThreadPool threads(tsl::Env::Default(), "benchmark",
-                                  num_threads);
+  tsl::thread::ThreadPool threads(tsl::Env::Default(), "bench", num_threads);
 
   for (auto _ : state) {
     absl::BlockingCounter counter(num_threads);
-    WorkQueue queue(/*num_tasks=*/1024 * 10, /*num_partitions=*/num_threads);
+    WorkQueue queue(/*num_work_items=*/1024 * 10,
+                    /*num_partitions=*/num_threads);
 
     for (size_t i = 0; i < num_threads; ++i) {
       threads.Schedule([i, &queue, &counter] {
         Worker worker(i, &queue);
-        while (std::optional<size_t> task = worker.Pop()) {
+        while (std::optional<size_t> work_item = worker.Pop()) {
         }
         counter.DecrementCount();
       });
@@ -291,7 +292,7 @@ static void BM_PopTaskMultiThreaded(benchmark::State& state) {
   state.SetItemsProcessed(state.iterations() * 1024 * 10);
 }
 
-BENCHMARK(BM_PopTaskMultiThreaded)
+BENCHMARK(BM_PopWorkItemMultiThreaded)
     ->MeasureProcessCPUTime()
     ->Arg(2)
     ->Arg(4)
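One behavior worth calling out from the Worker::Parallelize body above: work items may return absl::Status, and after the first failure the remaining popped work items are skipped, so the async value resolves to the first error. A hedged caller-side sketch of relying on that contract (assuming a valid Eigen::ThreadPoolInterface* pool and the headers from the earlier sketch; not code from this commit):

auto event = xla::cpu::Worker::Parallelize(
    pool, /*num_workers=*/4, /*num_work_items=*/100,
    [](size_t work_item) -> absl::Status {
      // The first failing work item determines the final status; later
      // work items are still popped from the queue but not executed.
      if (work_item == 7) return absl::InternalError("work item 7 failed");
      return absl::OkStatus();
    });
tsl::BlockUntilReady(event);
// Expected here: event.IsError() is true and carries the InternalError.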