[xla:cpu:ynn] Do not track work stealing workers

```
name                                                               cpu/op         cpu/op      vs base
BM_ParallelFor/8/1/process_time   [#threads=8, #threadpools=1  ]    5.470m ±  5%   5.095m ± 3%  -6.87% (p=0.000 n=80)
BM_ParallelFor/8/2/process_time   [#threads=8, #threadpools=2  ]    2.857m ±  1%   2.595m ± 2%  -9.15% (n=80)
BM_ParallelFor/8/4/process_time   [#threads=8, #threadpools=4  ]    1.447m ± 10%   1.328m ± 1%  -8.23% (p=0.000 n=80)
BM_ParallelFor/8/8/process_time   [#threads=8, #threadpools=8  ]   1058.1µ ± 20%   974.5µ ± 1%  -7.90% (p=0.000 n=80)
BM_ParallelFor/8/16/process_time  [#threads=8, #threadpools=16 ]    741.5µ ± 26%   705.8µ ± 1%  -4.81% (p=0.000 n=80)
BM_ParallelFor/16/1/process_time  [#threads=16, #threadpools=1 ]    9.796m ± 29%   9.972m ± 2%       ~ (p=0.312 n=80)
BM_ParallelFor/16/2/process_time  [#threads=16, #threadpools=2 ]    7.871m ± 28%   7.706m ± 1%  -2.10% (p=0.030 n=80)
BM_ParallelFor/16/4/process_time  [#threads=16, #threadpools=4 ]    4.330m ±  2%   4.157m ± 1%  -3.99% (p=0.000 n=80)
BM_ParallelFor/16/8/process_time  [#threads=16, #threadpools=8 ]    2.678m ±  2%   2.638m ± 1%  -1.49% (p=0.014 n=80)
BM_ParallelFor/16/16/process_time [#threads=16, #threadpools=16]    1.791m ±  1%   1.807m ± 1%       ~ (p=0.325 n=80)
BM_ParallelFor/32/1/process_time  [#threads=32, #threadpools=1 ]    15.33m ±  1%   15.41m ± 1%       ~ (p=0.215 n=80)
BM_ParallelFor/32/2/process_time  [#threads=32, #threadpools=2 ]    13.99m ±  1%   13.80m ± 2%       ~ (p=0.400 n=80)
BM_ParallelFor/32/4/process_time  [#threads=32, #threadpools=4 ]    9.415m ±  1%   9.172m ± 1%  -2.58% (p=0.000 n=80)
BM_ParallelFor/32/8/process_time  [#threads=32, #threadpools=8 ]    5.759m ±  1%   5.647m ± 1%  -1.95% (p=0.004 n=80)
BM_ParallelFor/32/16/process_time [#threads=32, #threadpools=16]    3.932m ±  1%   3.864m ± 1%  -1.72% (p=0.006 n=80)
geomean                                                            4.051m         3.916m       -3.32%

name                                                               time/op        time/op     vs base
BM_ParallelFor/8/1/process_time   [#threads=8, #threadpools=1  ]    651.2µ ±  3%   600.3µ ± 4%  -7.80% (p=0.000 n=80)
BM_ParallelFor/8/2/process_time   [#threads=8, #threadpools=2  ]    329.4µ ±  0%   298.6µ ± 2%  -9.35% (n=80)
BM_ParallelFor/8/4/process_time   [#threads=8, #threadpools=4  ]    169.3µ ± 12%   155.7µ ± 1%  -8.05% (p=0.000 n=80)
BM_ParallelFor/8/8/process_time   [#threads=8, #threadpools=8  ]    125.8µ ± 21%   115.7µ ± 1%  -8.08% (p=0.000 n=80)
BM_ParallelFor/8/16/process_time  [#threads=8, #threadpools=16 ]    95.41µ ± 24%   89.56µ ± 1%  -6.13% (p=0.000 n=80)
BM_ParallelFor/16/1/process_time  [#threads=16, #threadpools=1 ]   1015.8µ ±  1%   952.0µ ± 1%  -6.29% (n=80)
BM_ParallelFor/16/2/process_time  [#threads=16, #threadpools=2 ]    556.5µ ±  1%   522.6µ ± 1%  -6.09% (n=80)
BM_ParallelFor/16/4/process_time  [#threads=16, #threadpools=4 ]    289.7µ ±  2%   274.4µ ± 1%  -5.30% (p=0.000 n=80)
BM_ParallelFor/16/8/process_time  [#threads=16, #threadpools=8 ]    178.8µ ±  2%   174.1µ ± 1%  -2.59% (p=0.000 n=80)
BM_ParallelFor/16/16/process_time [#threads=16, #threadpools=16]    123.9µ ±  2%   123.0µ ± 1%       ~ (p=0.098 n=80)
BM_ParallelFor/32/1/process_time  [#threads=32, #threadpools=1 ]    1.526m ±  3%   1.433m ± 3%  -6.07% (p=0.000 n=80)
BM_ParallelFor/32/2/process_time  [#threads=32, #threadpools=2 ]    835.2µ ±  2%   783.5µ ± 2%  -6.19% (p=0.000 n=80)
BM_ParallelFor/32/4/process_time  [#threads=32, #threadpools=4 ]    471.6µ ±  2%   455.1µ ± 1%  -3.52% (p=0.000 n=80)
BM_ParallelFor/32/8/process_time  [#threads=32, #threadpools=8 ]    296.1µ ±  2%   287.0µ ± 2%  -3.08% (p=0.000 n=80)
BM_ParallelFor/32/16/process_time [#threads=32, #threadpools=16]    215.0µ ±  2%   211.6µ ± 1%  -1.59% (p=0.018 n=80)
geomean                                                            330.2µ         312.3µ       -5.42%
```

PiperOrigin-RevId: 824259124
This commit is contained in:
Eugene Zhulenev 2025-10-26 14:52:23 -07:00 committed by TensorFlower Gardener
parent e65144c31f
commit 5edcd28152
3 changed files with 23 additions and 12 deletions

View File

@@ -93,7 +93,11 @@ class Worker {
public:
Worker(size_t worker_index, WorkQueue* queue);
std::optional<size_t> Pop();
// Pops a work item from the work queue. If `notify_work_stealing` is true,
// the worker will notify the work queue when it switches to the work
// stealing mode. Worker parallelization has an optimization to avoid
// scheduling more workers if there are workers in the work stealing mode.
std::optional<size_t> Pop(bool notify_work_stealing = true);
// Schedule `num_workers` workers into the Eigen thread pool that process
// `num_work_items` parallel work items and return an async value that becomes
@@ -182,7 +186,7 @@ inline Worker::Worker(size_t worker_index, WorkQueue* queue)
partition_index_(worker_index),
queue_(queue) {}
inline std::optional<size_t> Worker::Pop() {
inline std::optional<size_t> Worker::Pop(bool notify_work_stealing) {
std::optional<size_t> work_item = queue_->Pop(partition_index_);
if (ABSL_PREDICT_TRUE(work_item)) {
return work_item;
@@ -190,7 +194,8 @@ inline std::optional<size_t> Worker::Pop() {
// If we didn't find a work item in the initially assigned partition, notify
// the work queue that we are switching to work stealing mode.
if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
if (ABSL_PREDICT_FALSE(notify_work_stealing &&
partition_index_ == worker_index_)) {
queue_->NotifyWorkStealingWorker();
}

View File

@@ -28,6 +28,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3",
"@local_tsl//tsl/profiler/lib:traceme",
"@slinky//slinky/base:thread_pool",
],
)

View File

@@ -34,6 +34,7 @@ limitations under the License.
#include "slinky/base/ref_count.h"
#include "slinky/base/thread_pool.h"
#include "xla/backends/cpu/runtime/work_queue.h"
#include "tsl/profiler/lib/traceme.h"
#define EIGEN_USE_THREADS
#include "Eigen/ThreadPool"
@@ -66,10 +67,14 @@ class Task final : public SlinkyThreadPool::task {
// Runs this task by processing work items in the current thread.
TaskState Run();
// Returns true if the work queue is empty. It doesn't mean that the task is
// complete, as some threads might still be working on this task.
bool IsEmptyWorkQueue() const;
// Returns the number of workers that are currently working on this task.
int64_t num_workers() const;
bool is_empty_work_queue() const;
// Returns true if the task is done.
bool done() const final;
private:
@@ -100,13 +105,13 @@ TaskState Task::Run() {
Worker w(worker_index, &work_queue_);
size_t num_processed_work_items = 0;
if (std::optional<size_t> item = w.Pop()) {
if (std::optional<size_t> item = w.Pop(/*notify_work_stealing=*/false)) {
SlinkyThreadPool::task_body body = body_;
do {
body(*item);
++num_processed_work_items;
} while ((item = w.Pop()).has_value());
} while ((item = w.Pop(/*notify_work_stealing=*/false)).has_value());
}
// The number of pending work items should never go below zero.
@@ -128,7 +133,7 @@ int64_t Task::num_workers() const {
return worker_index_.load(std::memory_order_relaxed);
}
bool Task::is_empty_work_queue() const { return work_queue_.IsEmpty(); }
bool Task::IsEmptyWorkQueue() const { return work_queue_.IsEmpty(); }
bool Task::done() const {
return pending_work_items_.load(std::memory_order_acquire) == 0;
@@ -231,7 +236,7 @@ slinky::ref_count<Task> SlinkyThreadPool::Impl::Dequeue() {
slinky::ref_count<Task>& task = *i;
// Task doesn't have any more work items to process.
if (ABSL_PREDICT_FALSE(task->is_empty_work_queue())) {
if (ABSL_PREDICT_FALSE(task->IsEmptyWorkQueue())) {
i = tasks_.erase(i);
continue;
}
@@ -278,6 +283,7 @@ void SlinkyThreadPool::Impl::WorkOnTasks(const absl::Condition& condition) {
void SlinkyThreadPool::Impl::Await(const absl::Condition& condition) {
if (ABSL_PREDICT_FALSE(!condition.Eval())) {
tsl::profiler::TraceMe trace("SlinkyThreadPool::Await");
absl::MutexLock lock(waiter_mutex_);
waiter_mutex_.Await(condition);
}
@@ -303,7 +309,7 @@ void SlinkyThreadPool::Impl::ScheduleWorkers(int64_t num_workers,
if (ABSL_PREDICT_TRUE(num_workers > 0 && CanScheduleWorkers())) {
slinky::ref_count<ScheduleState> state(
new ScheduleState(num_workers - 1, std::move(task), {this}));
threadpool_->Schedule([state = state.take()]() {
threadpool_->Schedule([state = state.take()] {
ScheduleWorkers</*release_impl_ref=*/false>(state);
});
}
@@ -321,8 +327,7 @@ void SlinkyThreadPool::Impl::ScheduleWorkers(ScheduleState* context) {
for (size_t i = 0; i < kNumRecursiveWorkers; ++i) {
bool schedule_worker =
state->impl->CanScheduleWorkers() &&
!state->task->is_empty_work_queue() &&
state->impl->CanScheduleWorkers() && !state->task->IsEmptyWorkQueue() &&
state->remaining_workers.fetch_sub(1, std::memory_order_relaxed) > 0;
if (ABSL_PREDICT_TRUE(!schedule_worker)) {
@@ -333,7 +338,7 @@ void SlinkyThreadPool::Impl::ScheduleWorkers(ScheduleState* context) {
// reference count to track the number of active workers.
state->impl->add_ref();
state->impl->threadpool_->Schedule(
[state = slinky::ref_count<ScheduleState>(state).take()]() {
[state = slinky::ref_count<ScheduleState>(state).take()] {
SlinkyThreadPool::Impl::ScheduleWorkers</*release_impl_ref=*/true>(
state);
});