[xla:cpu] Use work_item instead of a task in WorkQueue/Worker API
To avoid confusion between the different kinds of tasks we have in Worker/WorkQueue and the SlinkyThreadPool in XLA, use the more generic "work item" name.

PiperOrigin-RevId: 823191886
parent 512b85961f
commit d55e5c1d9f
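For context, a caller-side sketch of the renamed API. It is modeled directly on the `WorkerParallelizeVariousWorkerTaskRatios` test further down in this diff; the pool setup and sizes are illustrative assumptions, and includes are omitted because file paths are not part of this extract.

    // Illustrative usage of the renamed API (mirrors the test below; sizes
    // and pool setup are assumptions, not part of the commit).
    tsl::thread::ThreadPool threads(tsl::Env::Default(), "example",
                                    /*num_threads=*/8);

    std::vector<size_t> data(1024, 0);

    // Schedule 8 workers over 1024 work items; the callable runs once per
    // work item index.
    auto event = Worker::Parallelize(
        threads.AsEigenThreadPool(), /*num_workers=*/8, /*num_work_items=*/1024,
        [&](size_t work_item_index) { ++data[work_item_index]; });

    // The async value becomes available once all work items are processed.
    tsl::BlockUntilReady(event);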
@@ -62,10 +62,10 @@ static absl::InlinedVector<XLA_CPU_KernelArg, 8> ConvertBuffersToKernelArgs(
 }
 
 template <bool num_workgroups_x_only>
-class Kernel::ParallelTask {
+class Kernel::Task {
  public:
-  ParallelTask(XLA_CPU_Kernel* kernel, NumWorkGroups num_workgroups,
-               absl::Span<const XLA_CPU_KernelArg> args);
+  Task(XLA_CPU_Kernel* kernel, NumWorkGroups num_workgroups,
+       absl::Span<const XLA_CPU_KernelArg> args);
 
   // Invokes a host kernel for a given task index.
   absl::Status operator()(size_t task_index) const;
@@ -87,7 +87,7 @@ class Kernel::ParallelTask {
 };
 
 template <bool num_workgroups_x_only>
-Kernel::ParallelTask<num_workgroups_x_only>::ParallelTask(
+Kernel::Task<num_workgroups_x_only>::Task(
     XLA_CPU_Kernel* kernel, NumWorkGroups num_workgroups,
     absl::Span<const XLA_CPU_KernelArg> args)
     : kernel_(kernel),
@@ -98,7 +98,7 @@ Kernel::ParallelTask<num_workgroups_x_only>::ParallelTask(
       stride_y_(num_workgroups.x) {}
 
 template <bool num_workgroups_x_only>
-absl::Status Kernel::ParallelTask<num_workgroups_x_only>::operator()(
+absl::Status Kernel::Task<num_workgroups_x_only>::operator()(
     size_t task_index) const {
   DCHECK_LT(task_index, num_tasks_) << "Task index out of range";  // Crash OK
 
@@ -117,7 +117,7 @@ absl::Status Kernel::ParallelTask<num_workgroups_x_only>::operator()(
 }
 
 template <bool num_workgroups_x_only>
-XLA_CPU_WorkGroupId Kernel::ParallelTask<num_workgroups_x_only>::Delinearize(
+XLA_CPU_WorkGroupId Kernel::Task<num_workgroups_x_only>::Delinearize(
    uint64_t task_index) const {
  // In the most common case we parallelize only over the `x` dimension.
  if constexpr (num_workgroups_x_only) {
@@ -197,14 +197,12 @@ tsl::AsyncValueRef<LaunchEvent> Kernel::Launch(
             std::numeric_limits<uint16_t>::max());
 
   if (ABSL_PREDICT_TRUE(num_workgroups.y == 1 && num_workgroups.z == 1)) {
-    return Worker::Parallelize(
-        device->getPool(), num_workers, num_tasks,
-        ParallelTask<true>(kernel_, num_workgroups, args));
+    return Worker::Parallelize(device->getPool(), num_workers, num_tasks,
+                               Task<true>(kernel_, num_workgroups, args));
   }
 
-  return Worker::Parallelize(
-      device->getPool(), num_workers, num_tasks,
-      ParallelTask<false>(kernel_, num_workgroups, args));
+  return Worker::Parallelize(device->getPool(), num_workers, num_tasks,
+                             Task<false>(kernel_, num_workgroups, args));
 }
 
 }  // namespace xla::cpu
@@ -105,9 +105,9 @@ class Kernel {
   }
 
  private:
-  // A kernel parallel task that is used to parallelize host kernel execution.
+  // A kernel task that is used to parallelize host kernel execution.
   template <bool num_workgroups_x_only>
-  class ParallelTask;
+  class Task;
 
   std::unique_ptr<KernelFunction> function_;
   XLA_CPU_Kernel* kernel_;  // pointer to the kernel owned by `function_`
@@ -40,17 +40,17 @@ limitations under the License.
 
 namespace xla::cpu {
 
-// A work queue that partitions `num_tasks` tasks into `num_partitions`
-// partitions processed by parallel workers.
+// A work queue that partitions `num_work_items` work items into
+// `num_partitions` partitions processed by parallel workers.
 class WorkQueue {
  public:
-  WorkQueue(size_t num_tasks, size_t num_partitions);
+  WorkQueue(size_t num_work_items, size_t num_partitions);
 
-  // Returns the next task in the given partition. Returns std::nullopt
+  // Returns the next work item in the given partition. Returns std::nullopt
   // if the partition is complete.
   std::optional<size_t> Pop(size_t partition_index);
 
-  // Return the partition [begin, end) task range.
+  // Return the partition [begin, end) work items range.
   std::pair<size_t, size_t> partition_range(size_t partition_index) const;
 
   size_t num_partitions() const { return partitions_.size(); }
@@ -63,7 +63,7 @@ class WorkQueue {
   struct Partition {
     void Initialize(size_t begin, size_t end);
 
-    // Tracks index of the next task in the assigned partition.
+    // Tracks index of the next work item in the assigned partition.
     ABSL_CACHELINE_ALIGNED std::atomic<size_t> index;
     size_t begin;
     size_t end;
@@ -85,10 +85,10 @@ class WorkQueue {
   ABSL_CACHELINE_ALIGNED std::atomic<size_t> num_work_stealing_workers_;
 };
 
-// Worker processes tasks from the work queue starting from the assigned
+// Worker processes work items from the work queue starting from the assigned
 // work partition. Once the assigned partition is complete it tries to pop
-// the task from the next partition. Once the work queue is empty (the worker
-// wraps around to the initial partition) it returns and empty task.
+// the work item from the next partition. Once the work queue is empty (the
+// worker wraps around to the initial partition) it returns an empty work item.
 class Worker {
  public:
   Worker(size_t worker_index, WorkQueue* queue);
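To make the wrap-around concrete, a sketch based on the `WorkQueueTest.Worker` test below (not part of the commit). With 8 work items over 4 partitions ([0,2) [2,4) [4,6) [6,8)), a worker seeded at partition 2 drains its own partition first and then steals from partitions 3, 0, and 1 in cyclic order:

    // Single-threaded sketch: one worker drains the whole queue.
    WorkQueue queue(/*num_work_items=*/8, /*num_partitions=*/4);
    Worker worker(/*worker_index=*/2, &queue);  // starts at partition 2

    std::vector<size_t> work_items;
    while (std::optional<size_t> work_item = worker.Pop()) {
      work_items.push_back(*work_item);
    }
    // work_items is {4, 5, 6, 7, 0, 1, 2, 3}: the assigned partition first,
    // then the stolen partitions in wrap-around order (this is why the tests
    // sort before comparing).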
@@ -96,23 +96,23 @@ class Worker {
   std::optional<size_t> Pop();
 
   // Schedule `num_workers` workers into the Eigen thread pool that process
-  // `num_tasks` parallel tasks and return an async value that becomes
+  // `num_work_items` parallel work items and return an async value that becomes
   // available when all workers are completed.
-  template <typename ParallelTask>
+  template <typename ParallelWork>
   static tsl::AsyncValueRef<tsl::Chain> Parallelize(
       Eigen::ThreadPoolInterface* thread_pool, size_t num_workers,
-      size_t num_tasks, ParallelTask&& parallel_task);
+      size_t num_work_items, ParallelWork&& parallel_work);
 
  private:
-  template <typename ParallelTask>
+  template <typename ParallelWork>
   struct ParallelizeContext;
 
-  template <typename ParallelTask>
-  static absl::Status ExecuteInline(size_t num_tasks,
-                                    ParallelTask&& parallel_task);
+  template <typename ParallelWork>
+  static absl::Status ExecuteInline(size_t num_work_items,
+                                    ParallelWork&& parallel_work);
 
-  template <typename ParallelTask>
-  static void Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
+  template <typename ParallelWork>
+  static void Parallelize(std::shared_ptr<ParallelizeContext<ParallelWork>> ctx,
                           uint16_t start_index, uint16_t end_index);
 
   size_t worker_index_;
@@ -126,14 +126,14 @@ inline void WorkQueue::Partition::Initialize(size_t begin, size_t end) {
   this->end = end;
 }
 
-inline WorkQueue::WorkQueue(size_t num_tasks, size_t num_partitions)
+inline WorkQueue::WorkQueue(size_t num_work_items, size_t num_partitions)
     : partitions_(num_partitions),
-      empty_(num_tasks == 0),
+      empty_(num_work_items == 0),
       num_work_stealing_workers_(0) {
-  size_t partition_size = num_tasks / num_partitions;
-  size_t rem_tasks = num_tasks % num_partitions;
+  size_t partition_size = num_work_items / num_partitions;
+  size_t rem_work_items = num_work_items % num_partitions;
   for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
-    end = begin + partition_size + ((i < rem_tasks) ? 1 : 0);
+    end = begin + partition_size + ((i < rem_work_items) ? 1 : 0);
     partitions_[i].Initialize(begin, end);
   }
 }
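The constructor splits work items as evenly as possible: every partition gets `num_work_items / num_partitions` items, and the first `num_work_items % num_partitions` partitions get one extra. A standalone sketch of the same arithmetic (the `PartitionRanges` helper is hypothetical, written only for this note):

    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Re-implementation of the partitioning arithmetic above, for
    // illustration only (this helper does not exist in XLA).
    std::vector<std::pair<size_t, size_t>> PartitionRanges(
        size_t num_work_items, size_t num_partitions) {
      std::vector<std::pair<size_t, size_t>> ranges;
      size_t partition_size = num_work_items / num_partitions;
      size_t rem_work_items = num_work_items % num_partitions;
      for (size_t i = 0, begin = 0, end = 0; i < num_partitions;
           ++i, begin = end) {
        end = begin + partition_size + ((i < rem_work_items) ? 1 : 0);
        ranges.push_back({begin, end});
      }
      return ranges;
    }

    int main() {
      // Matches WorkQueueTest.WorkQueuePartitions below: 9 work items over
      // 4 partitions yield [0,3) [3,5) [5,7) [7,9).
      for (auto [begin, end] : PartitionRanges(9, 4)) {
        std::printf("[%zu, %zu)\n", begin, end);
      }
    }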
@@ -148,7 +148,7 @@ inline std::optional<size_t> WorkQueue::Pop(size_t partition_index) {
     return std::nullopt;
   }
 
-  // Try to acquire the next task in the partition.
+  // Try to acquire the next work item in the partition.
   size_t index = partition.index.fetch_add(1, std::memory_order_relaxed);
   return ABSL_PREDICT_FALSE(index >= partition.end) ? std::nullopt
                                                     : std::make_optional(index);
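Pop() stays correct under concurrency because each claim is a single atomic fetch_add on the partition cursor: two threads can never receive the same index, and a cursor that has run past `end` simply reports the partition as exhausted. This is what the `WorkerConcurrency` test below exercises. A minimal standalone sketch of the claim step (not XLA code):

    #include <atomic>
    #include <cstddef>
    #include <optional>

    std::optional<size_t> Claim(std::atomic<size_t>& cursor, size_t end) {
      // fetch_add hands every caller a unique index, even under contention.
      size_t index = cursor.fetch_add(1, std::memory_order_relaxed);
      return index >= end ? std::nullopt : std::make_optional(index);
    }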
@@ -183,18 +183,18 @@ inline Worker::Worker(size_t worker_index, WorkQueue* queue)
       queue_(queue) {}
 
 inline std::optional<size_t> Worker::Pop() {
-  std::optional<size_t> task = queue_->Pop(partition_index_);
-  if (ABSL_PREDICT_TRUE(task)) {
-    return task;
+  std::optional<size_t> work_item = queue_->Pop(partition_index_);
+  if (ABSL_PREDICT_TRUE(work_item)) {
+    return work_item;
   }
 
-  // If we didn't find a task in the initially assigned partition, notify the
-  // work queue that we are switching to work stealing mode.
+  // If we didn't find a work item in the initially assigned partition, notify
+  // the work queue that we are switching to work stealing mode.
   if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
     queue_->NotifyWorkStealingWorker();
   }
 
-  while (!task.has_value() && !queue_->IsEmpty()) {
+  while (!work_item.has_value() && !queue_->IsEmpty()) {
     // Wrap around to the first partition.
     if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
       partition_index_ = 0;
@@ -206,44 +206,44 @@ inline std::optional<size_t> Worker::Pop() {
       break;
     }
 
-    task = queue_->Pop(partition_index_);
+    work_item = queue_->Pop(partition_index_);
   }
 
-  return task;
+  return work_item;
 }
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 struct Worker::ParallelizeContext {
   ParallelizeContext(Eigen::ThreadPoolInterface* thread_pool,
                      tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
-                     size_t num_tasks, ParallelTask&& parallel_task);
+                     size_t num_work_items, ParallelWork&& parallel_work);
 
   Eigen::ThreadPoolInterface* thread_pool;
   tsl::CountDownAsyncValueRef<tsl::Chain> count_down;
 
   WorkQueue work_queue;
-  ParallelTask parallel_task;
+  ParallelWork parallel_work;
 };
 
-template <typename ParallelTask>
-Worker::ParallelizeContext<ParallelTask>::ParallelizeContext(
+template <typename ParallelWork>
+Worker::ParallelizeContext<ParallelWork>::ParallelizeContext(
     Eigen::ThreadPoolInterface* thread_pool,
-    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_tasks,
-    ParallelTask&& parallel_task)
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_work_items,
+    ParallelWork&& parallel_work)
     : thread_pool(thread_pool),
       count_down(std::move(count_down)),
-      work_queue(num_tasks, /*num_partitions=*/this->count_down.count()),
-      parallel_task(std::forward<ParallelTask>(parallel_task)) {}
+      work_queue(num_work_items, /*num_partitions=*/this->count_down.count()),
+      parallel_work(std::forward<ParallelWork>(parallel_work)) {}
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
+void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelWork>> ctx,
                          uint16_t start_index, uint16_t end_index) {
   DCHECK_LT(start_index, end_index) << "Invalid worker index range";
 
-  using R = std::invoke_result_t<ParallelTask, size_t>;
+  using R = std::invoke_result_t<ParallelWork, size_t>;
   static_assert(std::is_same_v<R, absl::Status> || std::is_void_v<R>,
-                "Unsupported parallel task return type");
+                "Unsupported parallel work return type");
 
   // Recursively split assigned workers into two halves and schedule the
   // right half into the thread pool.
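The comment above describes recursive splitting, but the splitting code itself falls between hunks and is not shown here. A standalone sketch of the strategy under that assumption: keep the left half of the worker range on the current thread and push the right half into the pool, so launching N workers costs O(log N) sequential scheduling steps on any single thread.

    #include <cstdint>
    #include <functional>

    #include <unsupported/Eigen/CXX11/ThreadPool>  // Eigen::ThreadPoolInterface

    // Hypothetical helper, written only for this note; the real logic lives
    // inside Worker::Parallelize.
    void LaunchWorkers(Eigen::ThreadPoolInterface* pool, uint16_t start_index,
                       uint16_t end_index, std::function<void(uint16_t)> run) {
      while (end_index - start_index > 1) {
        uint16_t mid = (start_index + end_index) / 2;
        // Schedule the right half [mid, end_index) into the thread pool ...
        pool->Schedule([=] { LaunchWorkers(pool, mid, end_index, run); });
        // ... and keep splitting the left half [start_index, mid) inline.
        end_index = mid;
      }
      run(start_index);  // execute the last remaining worker in the caller thread
    }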
@@ -254,7 +254,7 @@ void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
   }
 
   // If we have workers in the work stealing mode, we can skip scheduling
-  // more tasks as existing workers will process remaining partitions. By
+  // more workers as existing workers will process remaining partitions. By
   // doing this optimization we avoid unnecessary thread pool overheads.
   size_t skip_workers =
       ctx->work_queue.DecrementWorkStealingWorkers(end_index - start_index);
@@ -283,54 +283,54 @@ void Worker::Parallelize(std::shared_ptr<ParallelizeContext<ParallelTask>> ctx,
 
   // Execute the `start_index` worker in the caller thread.
   Worker worker(start_index, &ctx->work_queue);
-  size_t num_processed_tasks = 0;
+  size_t num_processed_work_items = 0;
 
   // Keep track of the first error status encountered by any of the workers.
   absl::Status status;
 
-  while (std::optional<size_t> task = worker.Pop()) {
+  while (std::optional<size_t> work_item = worker.Pop()) {
     if constexpr (std::is_same_v<R, absl::Status>) {
       if (ABSL_PREDICT_TRUE(status.ok())) {
-        status.Update(ctx->parallel_task(*task));
+        status.Update(ctx->parallel_work(*work_item));
       }
     } else {
-      ctx->parallel_task(*task);
+      ctx->parallel_work(*work_item);
     }
-    ++num_processed_tasks;
+    ++num_processed_work_items;
   }
 
-  ctx->count_down.CountDown(num_processed_tasks, std::move(status));
+  ctx->count_down.CountDown(num_processed_work_items, std::move(status));
 }
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 ABSL_ATTRIBUTE_ALWAYS_INLINE absl::Status Worker::ExecuteInline(
-    size_t num_tasks, ParallelTask&& parallel_task) {
-  using R = std::invoke_result_t<ParallelTask, size_t>;
+    size_t num_work_items, ParallelWork&& parallel_work) {
+  using R = std::invoke_result_t<ParallelWork, size_t>;
   static_assert(std::is_same_v<R, absl::Status> || std::is_void_v<R>,
-                "Unsupported parallel task return type");
+                "Unsupported parallel work return type");
 
-  for (size_t i = 0; i < num_tasks; ++i) {
+  for (size_t i = 0; i < num_work_items; ++i) {
     if constexpr (std::is_same_v<R, absl::Status>) {
-      absl::Status status = parallel_task(i);
+      absl::Status status = parallel_work(i);
       if (ABSL_PREDICT_FALSE(!status.ok())) {
         return status;
       }
     } else {
-      parallel_task(i);
+      parallel_work(i);
     }
   }
 
   return absl::OkStatus();
 }
 
-template <typename ParallelTask>
+template <typename ParallelWork>
 ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<tsl::Chain> Worker::Parallelize(
     Eigen::ThreadPoolInterface* thread_pool, size_t num_workers,
-    size_t num_tasks, ParallelTask&& parallel_task) {
+    size_t num_work_items, ParallelWork&& parallel_work) {
   // Short-circuit single-threaded execution.
   if (ABSL_PREDICT_FALSE(num_workers == 1)) {
-    if (absl::Status status =
-            ExecuteInline(num_tasks, std::forward<ParallelTask>(parallel_task));
+    if (absl::Status status = ExecuteInline(
+            num_work_items, std::forward<ParallelWork>(parallel_work));
         ABSL_PREDICT_FALSE(!status.ok())) {
       return status;
     }
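Completion is tracked by a count-down latch sized to the total number of work items (constructed in the next hunk): each worker counts down by however many items it popped, so the async value becomes available exactly when every item has been processed, regardless of how the items were distributed across workers. A sketch using only the `CountDownAsyncValueRef` calls visible in this diff (the counts are illustrative):

    // Completion-protocol sketch (standalone, not XLA code).
    tsl::CountDownAsyncValueRef<tsl::Chain> count_down(/*count=*/6);
    auto execute_event = count_down.AsRef();

    count_down.CountDown(4, absl::OkStatus());  // worker A processed 4 items
    // execute_event is still pending: 2 items remain.
    count_down.CountDown(2, absl::OkStatus());  // worker B processed the rest
    // execute_event is now available; a non-OK status from any worker would
    // have been forwarded into it instead.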
@@ -341,16 +341,16 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<tsl::Chain> Worker::Parallelize(
   if (ABSL_PREDICT_FALSE(num_workers > std::numeric_limits<uint16_t>::max())) {
     num_workers = std::numeric_limits<uint16_t>::max();
   }
-  // Ensure we don't launch more workers than tasks.
-  // Extra workers would be idle or cause out-of-bounds partition access.
-  num_workers = std::min(num_tasks, num_workers);
+  // Ensure we don't launch more workers than work items. Extra workers would be
+  // idle or cause out-of-bounds partition access.
+  num_workers = std::min(num_work_items, num_workers);
 
-  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
+  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_work_items);
   auto execute_event = count_down.AsRef();
 
-  auto ctx = std::make_shared<ParallelizeContext<ParallelTask>>(
-      thread_pool, std::move(count_down), num_tasks,
-      std::forward<ParallelTask>(parallel_task));
+  auto ctx = std::make_shared<ParallelizeContext<ParallelWork>>(
+      thread_pool, std::move(count_down), num_work_items,
+      std::forward<ParallelWork>(parallel_work));
 
   Parallelize(std::move(ctx), 0, num_workers);
 
@@ -41,7 +41,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   };
 
   {
-    WorkQueue queue(/*num_tasks=*/2, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/2, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 1));
     EXPECT_EQ(queue.partition_range(1), task_range(1, 2));
     EXPECT_EQ(queue.partition_range(2), task_range(2, 2));
@@ -49,7 +49,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/4, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/4, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 1));
     EXPECT_EQ(queue.partition_range(1), task_range(1, 2));
     EXPECT_EQ(queue.partition_range(2), task_range(2, 3));
@@ -57,7 +57,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/5, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/5, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 2));
     EXPECT_EQ(queue.partition_range(1), task_range(2, 3));
     EXPECT_EQ(queue.partition_range(2), task_range(3, 4));
@@ -65,7 +65,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/9, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/9, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 3));
     EXPECT_EQ(queue.partition_range(1), task_range(3, 5));
     EXPECT_EQ(queue.partition_range(2), task_range(5, 7));
@@ -73,7 +73,7 @@ TEST(WorkQueueTest, WorkQueuePartitions) {
   }
 
   {
-    WorkQueue queue(/*num_tasks=*/14, /*num_partitions=*/4);
+    WorkQueue queue(/*num_work_items=*/14, /*num_partitions=*/4);
     EXPECT_EQ(queue.partition_range(0), task_range(0, 4));
     EXPECT_EQ(queue.partition_range(1), task_range(4, 8));
     EXPECT_EQ(queue.partition_range(2), task_range(8, 11));
@@ -107,17 +107,17 @@ TEST(WorkQueueTest, WorkQueue) {
     for (size_t num_partitions : {1, 2, 3, 4, 5, 6, 7, 8}) {
       WorkQueue queue(size, num_partitions);
 
-      std::vector<size_t> expected_tasks(size);
-      absl::c_iota(expected_tasks, 0);
+      std::vector<size_t> expected_work_items(size);
+      absl::c_iota(expected_work_items, 0);
 
-      std::vector<size_t> tasks;
+      std::vector<size_t> work_items;
       for (size_t i = 0; i < num_partitions; ++i) {
-        while (std::optional<size_t> task = queue.Pop(i)) {
-          tasks.push_back(*task);
+        while (std::optional<size_t> work_item = queue.Pop(i)) {
+          work_items.push_back(*work_item);
         }
       }
 
-      EXPECT_EQ(tasks, expected_tasks);
+      EXPECT_EQ(work_items, expected_work_items);
     }
   }
 }
@@ -126,21 +126,21 @@ TEST(WorkQueueTest, Worker) {
   for (size_t size : {1, 2, 4, 8, 16, 32, 64}) {
     for (size_t num_partitions : {1, 2, 3, 4, 5, 6, 7, 8}) {
       // We check that no matter what is the initial partition, the worker
-      // processes all tasks in the queue.
+      // processes all work items in the queue.
       for (size_t i = 0; i < num_partitions; ++i) {
         WorkQueue queue(size, num_partitions);
         Worker worker(i, &queue);
 
-        std::vector<size_t> expected_tasks(size);
-        absl::c_iota(expected_tasks, 0);
+        std::vector<size_t> expected_work_items(size);
+        absl::c_iota(expected_work_items, 0);
 
-        std::vector<size_t> tasks;
-        while (std::optional<size_t> task = worker.Pop()) {
-          tasks.push_back(*task);
+        std::vector<size_t> work_items;
+        while (std::optional<size_t> work_item = worker.Pop()) {
+          work_items.push_back(*work_item);
         }
 
-        absl::c_sort(tasks);  // we pop tasks out of order
-        EXPECT_EQ(tasks, expected_tasks);
+        absl::c_sort(work_items);  // we pop work_items out of order
+        EXPECT_EQ(work_items, expected_work_items);
       }
     }
   }
@@ -154,22 +154,22 @@ TEST(WorkQueueTest, WorkerConcurrency) {
 
   WorkQueue queue(size, num_partitions);
 
-  // Check that we pop exactly `size` tasks.
-  std::atomic<size_t> num_tasks(0);
+  // Check that we pop exactly `size` work_items.
+  std::atomic<size_t> num_work_items(0);
 
   absl::BlockingCounter counter(num_partitions);
   for (size_t i = 0; i < num_partitions; ++i) {
     threads.Schedule([&, i] {
       Worker worker(i, &queue);
-      while (std::optional<size_t> task = worker.Pop()) {
-        ++num_tasks;
+      while (std::optional<size_t> work_item = worker.Pop()) {
+        ++num_work_items;
       }
       counter.DecrementCount();
     });
   }
 
   counter.Wait();
-  EXPECT_EQ(num_tasks.load(), size);
+  EXPECT_EQ(num_work_items.load(), size);
 }
 
 TEST(WorkQueueTest, WorkerParallelize) {
@@ -215,35 +215,36 @@ TEST(WorkQueueTest, WorkerParallelizeVariousWorkerTaskRatios) {
   tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 16);
 
   struct TestCase {
-    size_t num_tasks;
+    size_t num_work_items;
     size_t num_workers;
   };
 
   std::vector<TestCase> test_cases = {
-      {0, 1},     // Edge: no tasks
+      {0, 1},     // Edge: no work_items
       {1, 1},     // Edge: single task, single worker
       {1, 8},     // Edge: single task, many workers
       {8, 1},     // Serial execution
-      {8, 4},     // Fewer workers than tasks
+      {8, 4},     // Fewer workers than work_items
       {8, 8},     // Equal
-      {8, 16},    // More workers than tasks
-      {1024, 8},  // Many tasks, fewer workers
-      {1024, 64}  // Many tasks, many workers
+      {8, 16},    // More workers than work_items
+      {1024, 8},  // Many work_items, fewer workers
+      {1024, 64}  // Many work_items, many workers
   };
 
   for (const auto& test : test_cases) {
-    std::vector<size_t> data(test.num_tasks, 0);
+    std::vector<size_t> data(test.num_work_items, 0);
 
     auto event = Worker::Parallelize(
-        threads.AsEigenThreadPool(), test.num_workers, test.num_tasks,
+        threads.AsEigenThreadPool(), test.num_workers, test.num_work_items,
         [&](size_t task_index) { ++data[task_index]; });
 
     tsl::BlockUntilReady(event);
 
-    // Verify that all tasks were executed once (if any exist)
-    std::vector<size_t> expected(test.num_tasks, 1);
-    EXPECT_EQ(data, expected) << "Failed for num_tasks=" << test.num_tasks
-                              << ", num_workers=" << test.num_workers;
+    // Verify that all work_items were executed once (if any exist)
+    std::vector<size_t> expected(test.num_work_items, 1);
+    EXPECT_EQ(data, expected)
+        << "Failed for num_work_items=" << test.num_work_items
+        << ", num_workers=" << test.num_workers;
   }
 }
@@ -251,35 +252,35 @@ TEST(WorkQueueTest, WorkerParallelizeVariousWorkerTaskRatios) {
 // Performance benchmarks.
 //===----------------------------------------------------------------------===//
 
-static void BM_PopTask(benchmark::State& state) {
+static void BM_PopWorkItem(benchmark::State& state) {
   std::optional<WorkQueue> queue;
   std::optional<Worker> worker;
 
   size_t n = 0;
   for (auto _ : state) {
     if (n++ % (1024 * 10) == 0) {
-      queue.emplace(/*num_tasks=*/1024 * 10, /*num_partitions=*/10);
+      queue.emplace(/*num_work_items=*/1024 * 10, /*num_partitions=*/10);
       worker.emplace(0, &*queue);
     }
     worker->Pop();
   }
 }
 
-BENCHMARK(BM_PopTask);
+BENCHMARK(BM_PopWorkItem);
 
-static void BM_PopTaskMultiThreaded(benchmark::State& state) {
+static void BM_PopWorkItemMultiThreaded(benchmark::State& state) {
   size_t num_threads = state.range(0);
-  tsl::thread::ThreadPool threads(tsl::Env::Default(), "benchmark",
-                                  num_threads);
+  tsl::thread::ThreadPool threads(tsl::Env::Default(), "bench", num_threads);
 
   for (auto _ : state) {
     absl::BlockingCounter counter(num_threads);
-    WorkQueue queue(/*num_tasks=*/1024 * 10, /*num_partitions=*/num_threads);
+    WorkQueue queue(/*num_work_items=*/1024 * 10,
+                    /*num_partitions=*/num_threads);
 
     for (size_t i = 0; i < num_threads; ++i) {
       threads.Schedule([i, &queue, &counter] {
         Worker worker(i, &queue);
-        while (std::optional<size_t> task = worker.Pop()) {
+        while (std::optional<size_t> work_item = worker.Pop()) {
         }
         counter.DecrementCount();
       });
@@ -291,7 +292,7 @@ static void BM_PopTaskMultiThreaded(benchmark::State& state) {
   state.SetItemsProcessed(state.iterations() * 1024 * 10);
 }
 
-BENCHMARK(BM_PopTaskMultiThreaded)
+BENCHMARK(BM_PopWorkItemMultiThreaded)
     ->MeasureProcessCPUTime()
     ->Arg(2)
    ->Arg(4)