pytorch/test/cpp/nativert/test_mpmc_queue.cpp
Zhengxu Chen fe11d300ac [nativert] Improve MPMCQueue tests. (#153154)
Summary:
- Use std::this_thread::yield and stop busy waiting.
- Sort test file orders.

Following up @swolchok's comment from https://github.com/pytorch/pytorch/pull/152837
Test Plan: CI

Differential Revision: D74402536

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153154
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-05-09 19:25:42 +00:00

135 lines
3.8 KiB
C++

#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

#include <torch/nativert/detail/MPMCQueue.h>

using torch::nativert::detail::MPMCQueue;

// A queue that has never been written to has nothing to hand out, so the
// non-blocking read is expected to report failure.
TEST(MPMCQueueTest, EmptyQueue) {
  MPMCQueue<int> q(5);

  int received = 0;
  EXPECT_FALSE(q.readIfNotEmpty(received));
}

// Round-trips one value: the write succeeds, the read succeeds, and the value
// read equals the value written.
TEST(MPMCQueueTest, SingleElement) {
  constexpr int kValue = 10;
  MPMCQueue<int> q(5);

  EXPECT_TRUE(q.writeIfNotFull(kValue));

  int received = 0;
  EXPECT_TRUE(q.readIfNotEmpty(received));
  EXPECT_EQ(received, kValue);
}

// Fills the queue with 0..4 and drains it again, checking that the values
// come back in the same order they were written.
TEST(MPMCQueueTest, MultipleElements) {
  constexpr int kCount = 5;
  MPMCQueue<int> q(5);

  // Enqueue 0, 1, ..., kCount - 1; every write must be accepted.
  for (int value = 0; value != kCount; ++value) {
    EXPECT_TRUE(q.writeIfNotFull(value));
  }

  // Dequeue and compare against the order of insertion.
  for (int expected = 0; expected != kCount; ++expected) {
    int received = 0;
    EXPECT_TRUE(q.readIfNotEmpty(received));
    EXPECT_EQ(received, expected);
  }
}

// After as many successful writes as the capacity passed to the constructor,
// one more non-blocking write is expected to be rejected.
TEST(MPMCQueueTest, FullQueue) {
  constexpr int kCapacity = 5;
  MPMCQueue<int> q(5);

  for (int value = 0; value != kCapacity; ++value) {
    EXPECT_TRUE(q.writeIfNotFull(value));
  }

  // The queue now holds kCapacity values; this extra write must fail.
  EXPECT_FALSE(q.writeIfNotFull(10));
}

// One writer thread and one reader thread share a queue concurrently.
// The writer pushes 0..4; the reader polls (yielding between attempts) until
// it has received five values and checks each one.
TEST(MPMCQueueTest, ConcurrentAccess) {
  // Capacity (10) is larger than the number of values written (5), so no
  // write in this test should ever be rejected.
  MPMCQueue<int> queue(10);

  std::thread writer([&queue]() {
    for (int i = 0; i < 5; ++i) {
      // Check the result instead of discarding it: a rejected write would
      // otherwise go unnoticed and leave the reader below polling forever.
      EXPECT_TRUE(queue.writeIfNotFull(i));
    }
  });

  std::thread reader([&queue]() {
    for (int i = 0; i < 5; ++i) {
      int out = 0;
      while (!queue.readIfNotEmpty(out)) {
        // Wait until an element is available.
        // TODO We could provide a blocking version of read() instead of
        // looping here. We only provide a non-blocking API because
        // for now the queue is paired with a semaphore in executor.
        std::this_thread::yield();
      }
      // With exactly one writer and one reader, values must arrive in the
      // order they were written (the same ordering MultipleElements checks
      // single-threaded), so the i-th value read is i.
      EXPECT_EQ(out, i);
    }
  });

  writer.join();
  reader.join();
}

// Stress test with several writers and several readers on one queue.
// Each writer pushes a disjoint range of integers; the readers drain the
// queue until the expected number of values has been read. The test then
// checks both the number of values read and their sum.
TEST(MPMCQueueTest, MPMCConcurrentAccess) {
  const size_t queueCapacity = 100000;
  const size_t numWriters = 5;
  const size_t numReaders = 5;
  const size_t numElementsPerWriter = 10000;
  const size_t totalElements = numWriters * numElementsPerWriter;
  MPMCQueue<int> queue(queueCapacity);

  // Writer threads
  std::vector<std::thread> writers;
  writers.reserve(numWriters);
  for (size_t i = 0; i < numWriters; ++i) {
    // `i` is captured BY VALUE. Capturing the loop variable by reference
    // would race with the `++i` above and dangle once this loop finishes
    // while the thread is still running.
    writers.emplace_back([&, i]() {
      for (size_t j = 0; j < numElementsPerWriter; ++j) {
        // Writer i produces the values [i * N, (i + 1) * N), so across all
        // writers every value in [0, totalElements) is written exactly once.
        const size_t value = i * numElementsPerWriter + j;
        while (!queue.writeIfNotFull(static_cast<int>(value))) {
          // Retry until the queue has space.
          // TODO We could provide a blocking version of write() instead of
          // looping here. We only provide a non-blocking API because
          // for now the queue is paired with a semaphore in executor.
          std::this_thread::yield();
        }
      }
    });
  }

  // Reader threads
  std::vector<std::thread> readers;
  std::atomic<size_t> totalReadCount{0};
  // Sum of every value read; used as a checksum on the queue contents.
  std::atomic<size_t> totalReadSum{0};
  readers.reserve(numReaders);
  for (size_t r = 0; r < numReaders; ++r) {
    readers.emplace_back([&]() {
      int value = 0;
      while (totalReadCount.load() < totalElements) {
        if (queue.readIfNotEmpty(value)) {
          totalReadSum += static_cast<size_t>(value);
          ++totalReadCount;
        } else {
          // TODO We could provide a blocking version of read() instead of
          // looping here. We only provide a non-blocking API because
          // for now the queue is paired with a semaphore in executor.
          std::this_thread::yield();
        }
      }
    });
  }

  // Join all threads
  for (auto& writer : writers) {
    writer.join();
  }
  for (auto& reader : readers) {
    reader.join();
  }

  // Verify that all elements were read.
  EXPECT_EQ(totalReadCount.load(), totalElements);
  // The written values are exactly 0, 1, ..., totalElements - 1, so a lost,
  // duplicated or corrupted element shows up as a wrong sum.
  EXPECT_EQ(totalReadSum.load(), totalElements * (totalElements - 1) / 2);
}

// Instantiates the queue with an element type whose copy operations are
// deleted, then writes and reads one such element.
TEST(MPMCQueueTest, MoveOnlyType) {
  struct MoveOnly {
    MoveOnly() = default;
    ~MoveOnly() = default;

    // Moving is allowed...
    MoveOnly(MoveOnly&&) = default;
    MoveOnly& operator=(MoveOnly&&) = default;

    // ...copying is not.
    MoveOnly(const MoveOnly&) = delete;
    MoveOnly& operator=(const MoveOnly&) = delete;
  };

  MPMCQueue<MoveOnly> q(5);
  EXPECT_TRUE(q.writeIfNotFull(MoveOnly()));

  MoveOnly received;
  EXPECT_TRUE(q.readIfNotEmpty(received));
}