[xla:cpu:ynn] Do not track work stealing workers
```
name                                                              cpu/op         cpu/op        vs base
BM_ParallelFor/8/1/process_time   [#threads=8,  #threadpools=1 ]  5.470m ± 5%    5.095m ± 3%   -6.87% (p=0.000 n=80)
BM_ParallelFor/8/2/process_time   [#threads=8,  #threadpools=2 ]  2.857m ± 1%    2.595m ± 2%   -9.15% (n=80)
BM_ParallelFor/8/4/process_time   [#threads=8,  #threadpools=4 ]  1.447m ± 10%   1.328m ± 1%   -8.23% (p=0.000 n=80)
BM_ParallelFor/8/8/process_time   [#threads=8,  #threadpools=8 ]  1058.1µ ± 20%  974.5µ ± 1%   -7.90% (p=0.000 n=80)
BM_ParallelFor/8/16/process_time  [#threads=8,  #threadpools=16]  741.5µ ± 26%   705.8µ ± 1%   -4.81% (p=0.000 n=80)
BM_ParallelFor/16/1/process_time  [#threads=16, #threadpools=1 ]  9.796m ± 29%   9.972m ± 2%   ~      (p=0.312 n=80)
BM_ParallelFor/16/2/process_time  [#threads=16, #threadpools=2 ]  7.871m ± 28%   7.706m ± 1%   -2.10% (p=0.030 n=80)
BM_ParallelFor/16/4/process_time  [#threads=16, #threadpools=4 ]  4.330m ± 2%    4.157m ± 1%   -3.99% (p=0.000 n=80)
BM_ParallelFor/16/8/process_time  [#threads=16, #threadpools=8 ]  2.678m ± 2%    2.638m ± 1%   -1.49% (p=0.014 n=80)
BM_ParallelFor/16/16/process_time [#threads=16, #threadpools=16]  1.791m ± 1%    1.807m ± 1%   ~      (p=0.325 n=80)
BM_ParallelFor/32/1/process_time  [#threads=32, #threadpools=1 ]  15.33m ± 1%    15.41m ± 1%   ~      (p=0.215 n=80)
BM_ParallelFor/32/2/process_time  [#threads=32, #threadpools=2 ]  13.99m ± 1%    13.80m ± 2%   ~      (p=0.400 n=80)
BM_ParallelFor/32/4/process_time  [#threads=32, #threadpools=4 ]  9.415m ± 1%    9.172m ± 1%   -2.58% (p=0.000 n=80)
BM_ParallelFor/32/8/process_time  [#threads=32, #threadpools=8 ]  5.759m ± 1%    5.647m ± 1%   -1.95% (p=0.004 n=80)
BM_ParallelFor/32/16/process_time [#threads=32, #threadpools=16]  3.932m ± 1%    3.864m ± 1%   -1.72% (p=0.006 n=80)
geomean                                                           4.051m         3.916m        -3.32%

name                                                              time/op        time/op       vs base
BM_ParallelFor/8/1/process_time   [#threads=8,  #threadpools=1 ]  651.2µ ± 3%    600.3µ ± 4%   -7.80% (p=0.000 n=80)
BM_ParallelFor/8/2/process_time   [#threads=8,  #threadpools=2 ]  329.4µ ± 0%    298.6µ ± 2%   -9.35% (n=80)
BM_ParallelFor/8/4/process_time   [#threads=8,  #threadpools=4 ]  169.3µ ± 12%   155.7µ ± 1%   -8.05% (p=0.000 n=80)
BM_ParallelFor/8/8/process_time   [#threads=8,  #threadpools=8 ]  125.8µ ± 21%   115.7µ ± 1%   -8.08% (p=0.000 n=80)
BM_ParallelFor/8/16/process_time  [#threads=8,  #threadpools=16]  95.41µ ± 24%   89.56µ ± 1%   -6.13% (p=0.000 n=80)
BM_ParallelFor/16/1/process_time  [#threads=16, #threadpools=1 ]  1015.8µ ± 1%   952.0µ ± 1%   -6.29% (n=80)
BM_ParallelFor/16/2/process_time  [#threads=16, #threadpools=2 ]  556.5µ ± 1%    522.6µ ± 1%   -6.09% (n=80)
BM_ParallelFor/16/4/process_time  [#threads=16, #threadpools=4 ]  289.7µ ± 2%    274.4µ ± 1%   -5.30% (p=0.000 n=80)
BM_ParallelFor/16/8/process_time  [#threads=16, #threadpools=8 ]  178.8µ ± 2%    174.1µ ± 1%   -2.59% (p=0.000 n=80)
BM_ParallelFor/16/16/process_time [#threads=16, #threadpools=16]  123.9µ ± 2%    123.0µ ± 1%   ~      (p=0.098 n=80)
BM_ParallelFor/32/1/process_time  [#threads=32, #threadpools=1 ]  1.526m ± 3%    1.433m ± 3%   -6.07% (p=0.000 n=80)
BM_ParallelFor/32/2/process_time  [#threads=32, #threadpools=2 ]  835.2µ ± 2%    783.5µ ± 2%   -6.19% (p=0.000 n=80)
BM_ParallelFor/32/4/process_time  [#threads=32, #threadpools=4 ]  471.6µ ± 2%    455.1µ ± 1%   -3.52% (p=0.000 n=80)
BM_ParallelFor/32/8/process_time  [#threads=32, #threadpools=8 ]  296.1µ ± 2%    287.0µ ± 2%   -3.08% (p=0.000 n=80)
BM_ParallelFor/32/16/process_time [#threads=32, #threadpools=16]  215.0µ ± 2%    211.6µ ± 1%   -1.59% (p=0.018 n=80)
geomean                                                           330.2µ         312.3µ        -5.42%
```

PiperOrigin-RevId: 824259124
parent e65144c31f
commit 5edcd28152
Changes to the `Worker` work-queue API:

```diff
@@ -93,7 +93,11 @@ class Worker {
  public:
   Worker(size_t worker_index, WorkQueue* queue);
 
-  std::optional<size_t> Pop();
+  // Pops a work item from the work queue. If `notify_work_stealing` is true,
+  // the worker will notify the work queue when it switches to the work
+  // stealing mode. Worker parallelization has an optimization to avoid
+  // scheduling more workers if there are workers in the work stealing mode.
+  std::optional<size_t> Pop(bool notify_work_stealing = true);
 
   // Schedule `num_workers` workers into the Eigen thread pool that process
   // `num_work_items` parallel work items and return an async value that becomes
@@ -182,7 +186,7 @@ inline Worker::Worker(size_t worker_index, WorkQueue* queue)
       partition_index_(worker_index),
       queue_(queue) {}
 
-inline std::optional<size_t> Worker::Pop() {
+inline std::optional<size_t> Worker::Pop(bool notify_work_stealing) {
   std::optional<size_t> work_item = queue_->Pop(partition_index_);
   if (ABSL_PREDICT_TRUE(work_item)) {
     return work_item;
@@ -190,7 +194,8 @@ inline std::optional<size_t> Worker::Pop() {
 
   // If we didn't find a work item in the initially assigned partition, notify
   // the work queue that we are switching to work stealing mode.
-  if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
+  if (ABSL_PREDICT_FALSE(notify_work_stealing &&
+                         partition_index_ == worker_index_)) {
     queue_->NotifyWorkStealingWorker();
   }
 
```
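
The `Pop` comment above describes the mechanism this commit opts some callers out of: a worker that exhausts its own partition reports itself as a work-stealing worker, and the parallelization logic avoids scheduling extra workers while such workers exist. The snippet below is a minimal, self-contained sketch of that idea, not the XLA implementation; `ToyWorkQueue` and its members are hypothetical names.

```cpp
#include <atomic>
#include <cstddef>
#include <optional>

// Toy single-partition queue that counts callers who report themselves as
// work-stealing once the queue is drained. A scheduler could consult
// HasStealingWorkers() to avoid launching more workers than useful.
class ToyWorkQueue {
 public:
  explicit ToyWorkQueue(size_t num_items) : size_(num_items) {}

  // Pops the next item. With `notify_work_stealing` set to false the caller
  // is never counted as a stealing worker, which is the behavior the commit
  // enables for workers that drain the queue directly on their own thread.
  std::optional<size_t> Pop(bool notify_work_stealing = true) {
    size_t i = next_.fetch_add(1, std::memory_order_relaxed);
    if (i < size_) return i;
    if (notify_work_stealing) {
      stealing_.fetch_add(1, std::memory_order_relaxed);
    }
    return std::nullopt;
  }

  // True if at least one caller has switched to work stealing.
  bool HasStealingWorkers() const {
    return stealing_.load(std::memory_order_relaxed) > 0;
  }

 private:
  size_t size_;
  std::atomic<size_t> next_{0};
  std::atomic<size_t> stealing_{0};
};
```
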
BUILD dependency update:

```diff
@@ -28,6 +28,7 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/synchronization",
         "@eigen_archive//:eigen3",
+        "@local_tsl//tsl/profiler/lib:traceme",
         "@slinky//slinky/base:thread_pool",
     ],
 )
```
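
The new `@local_tsl//tsl/profiler/lib:traceme` dependency backs the `tsl::profiler::TraceMe` annotation added to `Await` in the implementation changes below. A minimal usage sketch, assuming only the string constructor that the diff itself uses (the surrounding function is illustrative):

```cpp
#include "tsl/profiler/lib/traceme.h"

void AwaitExample() {
  // Scoped trace annotation: records an event spanning this scope when a
  // profiler session is active, and is cheap otherwise.
  tsl::profiler::TraceMe trace("SlinkyThreadPool::Await");
  // ... blocking wait would go here ...
}
```
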
Slinky thread pool implementation changes:

```diff
@@ -34,6 +34,7 @@ limitations under the License.
 #include "slinky/base/ref_count.h"
 #include "slinky/base/thread_pool.h"
 #include "xla/backends/cpu/runtime/work_queue.h"
+#include "tsl/profiler/lib/traceme.h"
 
 #define EIGEN_USE_THREADS
 #include "Eigen/ThreadPool"
@@ -66,10 +67,14 @@ class Task final : public SlinkyThreadPool::task {
   // Runs this task by processing work items in the current thread.
   TaskState Run();
 
+  // Returns true if the work queue is empty. It doesn't mean that the task is
+  // complete, as some threads might still be working on this task.
+  bool IsEmptyWorkQueue() const;
+
   // Returns the number of workers that are currently working on this task.
   int64_t num_workers() const;
 
-  bool is_empty_work_queue() const;
+  // Returns true if the task is done.
   bool done() const final;
 
  private:
@@ -100,13 +105,13 @@ TaskState Task::Run() {
   Worker w(worker_index, &work_queue_);
   size_t num_processed_work_items = 0;
 
-  if (std::optional<size_t> item = w.Pop()) {
+  if (std::optional<size_t> item = w.Pop(/*notify_work_stealing=*/false)) {
     SlinkyThreadPool::task_body body = body_;
 
     do {
       body(*item);
       ++num_processed_work_items;
-    } while ((item = w.Pop()).has_value());
+    } while ((item = w.Pop(/*notify_work_stealing=*/false)).has_value());
   }
 
   // The number of pending work items should never go below zero.
@@ -128,7 +133,7 @@ int64_t Task::num_workers() const {
   return worker_index_.load(std::memory_order_relaxed);
 }
 
-bool Task::is_empty_work_queue() const { return work_queue_.IsEmpty(); }
+bool Task::IsEmptyWorkQueue() const { return work_queue_.IsEmpty(); }
 
 bool Task::done() const {
   return pending_work_items_.load(std::memory_order_acquire) == 0;
@@ -231,7 +236,7 @@ slinky::ref_count<Task> SlinkyThreadPool::Impl::Dequeue() {
     slinky::ref_count<Task>& task = *i;
 
     // Task doesn't have any more work items to process.
-    if (ABSL_PREDICT_FALSE(task->is_empty_work_queue())) {
+    if (ABSL_PREDICT_FALSE(task->IsEmptyWorkQueue())) {
       i = tasks_.erase(i);
       continue;
     }
@@ -278,6 +283,7 @@ void SlinkyThreadPool::Impl::WorkOnTasks(const absl::Condition& condition) {
 
 void SlinkyThreadPool::Impl::Await(const absl::Condition& condition) {
   if (ABSL_PREDICT_FALSE(!condition.Eval())) {
+    tsl::profiler::TraceMe trace("SlinkyThreadPool::Await");
     absl::MutexLock lock(waiter_mutex_);
     waiter_mutex_.Await(condition);
   }
@@ -303,7 +309,7 @@ void SlinkyThreadPool::Impl::ScheduleWorkers(int64_t num_workers,
   if (ABSL_PREDICT_TRUE(num_workers > 0 && CanScheduleWorkers())) {
     slinky::ref_count<ScheduleState> state(
         new ScheduleState(num_workers - 1, std::move(task), {this}));
-    threadpool_->Schedule([state = state.take()]() {
+    threadpool_->Schedule([state = state.take()] {
       ScheduleWorkers</*release_impl_ref=*/false>(state);
     });
   }
@@ -321,8 +327,7 @@ void SlinkyThreadPool::Impl::ScheduleWorkers(ScheduleState* context) {
 
   for (size_t i = 0; i < kNumRecursiveWorkers; ++i) {
     bool schedule_worker =
-        state->impl->CanScheduleWorkers() &&
-        !state->task->is_empty_work_queue() &&
+        state->impl->CanScheduleWorkers() && !state->task->IsEmptyWorkQueue() &&
        state->remaining_workers.fetch_sub(1, std::memory_order_relaxed) > 0;
 
    if (ABSL_PREDICT_TRUE(!schedule_worker)) {
@@ -333,7 +338,7 @@ void SlinkyThreadPool::Impl::ScheduleWorkers(ScheduleState* context) {
    // reference count to track the number of active workers.
    state->impl->add_ref();
    state->impl->threadpool_->Schedule(
-        [state = slinky::ref_count<ScheduleState>(state).take()]() {
+        [state = slinky::ref_count<ScheduleState>(state).take()] {
          SlinkyThreadPool::Impl::ScheduleWorkers</*release_impl_ref=*/true>(
              state);
        });
```
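
The `IsEmptyWorkQueue` comment above draws a distinction worth spelling out: a task's work queue can be empty (every item has been claimed) while the task is not yet done (some claimed items are still executing on other threads). The toy class below illustrates that gap with the same atomic-counter pattern as the diff; it is a sketch with hypothetical names, not the XLA `Task` class.

```cpp
#include <atomic>
#include <cstddef>
#include <optional>

class ToyTask {
 public:
  explicit ToyTask(size_t num_items) : size_(num_items), pending_(num_items) {}

  // Claims the next unclaimed work item, if any remain.
  std::optional<size_t> Claim() {
    size_t i = next_.fetch_add(1, std::memory_order_relaxed);
    if (i < size_) return i;
    return std::nullopt;
  }

  // Marks one previously claimed item as finished.
  void Finish() { pending_.fetch_sub(1, std::memory_order_acq_rel); }

  // True once every item has been claimed; workers may still be running.
  bool IsEmptyWorkQueue() const {
    return next_.load(std::memory_order_relaxed) >= size_;
  }

  // True only when every claimed item has also finished executing.
  bool done() const {
    return pending_.load(std::memory_order_acquire) == 0;
  }

 private:
  size_t size_;
  std::atomic<size_t> next_{0};
  std::atomic<size_t> pending_;
};
```
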