mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
The feature was never fully finished and never got any adoption, but TCPStore pays the cost of twice the number of TCP connections anyway. While the cost of all those idle connections is minimal, it doesn't come for free: - It increases the likelihood of a connection refused failure during the initialization stampede. - TCPStore uses poll for checking for socket availability, which scales linearly with the number of sockets regardless of their status. Pull Request resolved: https://github.com/pytorch/pytorch/pull/105014 Approved by: https://github.com/fduwjj
214 lines
6.9 KiB
C++
214 lines
6.9 KiB
C++
#include <c10/util/irange.h>
|
|
#include "StoreTestCommon.hpp"
|
|
|
|
#include <cstdlib>
|
|
#include <future>
|
|
#include <iostream>
|
|
#include <system_error>
|
|
#include <thread>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
|
|
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
|
|
|
|
// Store timeout, in seconds, used when constructing the stores in these tests.
constexpr int defaultTimeout = 20;
// Short store timeout, in milliseconds, applied before a get() that is
// expected to fail so the failure is reported quickly.
constexpr int64_t kShortStoreTimeoutMillis = 100;
|
|
|
|
// Creates a TCPStore server on 127.0.0.1 with port 0 and waitWorkers
// disabled.
//
// numWorkers: worker count forwarded to the store options.
// timeout:    store timeout, in seconds.
c10::intrusive_ptr<c10d::TCPStore> _createServer(
    int numWorkers = 1,
    int timeout = defaultTimeout) {
  const c10d::TCPStoreOptions serverOpts{
      /* port */ 0,
      /* isServer */ true,
      numWorkers,
      /* waitWorkers */ false,
      /* timeout */ std::chrono::seconds(timeout)};
  return c10::make_intrusive<c10d::TCPStore>("127.0.0.1", serverOpts);
}
|
|
|
|
// Different ports for different tests.
|
|
void testHelper(const std::string& prefix = "") {
|
|
constexpr auto numThreads = 16;
|
|
constexpr auto numWorkers = numThreads + 1;
|
|
|
|
auto serverTCPStore = _createServer(numWorkers);
|
|
|
|
auto serverStore =
|
|
c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
|
|
// server store
|
|
auto serverThread = std::thread([&serverStore, &serverTCPStore] {
|
|
// Wait for all workers to join.
|
|
serverTCPStore->waitForWorkers();
|
|
|
|
// Basic set/get on the server store
|
|
c10d::test::set(*serverStore, "key0", "value0");
|
|
c10d::test::set(*serverStore, "key1", "value1");
|
|
c10d::test::set(*serverStore, "key2", "value2");
|
|
c10d::test::check(*serverStore, "key0", "value0");
|
|
c10d::test::check(*serverStore, "key1", "value1");
|
|
c10d::test::check(*serverStore, "key2", "value2");
|
|
serverStore->add("counter", 1);
|
|
auto numKeys = serverStore->getNumKeys();
|
|
// We expect 5 keys since 3 are added above, 'counter' is added by the
|
|
// helper thread, and the init key to coordinate workers.
|
|
EXPECT_EQ(numKeys, 5);
|
|
|
|
// Check compareSet, does not check return value
|
|
c10d::test::compareSet(
|
|
*serverStore, "key0", "wrongExpectedValue", "newValue");
|
|
c10d::test::check(*serverStore, "key0", "value0");
|
|
c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
|
|
c10d::test::check(*serverStore, "key0", "newValue");
|
|
|
|
auto delSuccess = serverStore->deleteKey("key0");
|
|
// Ensure that the key was successfully deleted
|
|
EXPECT_TRUE(delSuccess);
|
|
auto delFailure = serverStore->deleteKey("badKeyName");
|
|
// The key was not in the store so the delete operation should have failed
|
|
// and returned false.
|
|
EXPECT_FALSE(delFailure);
|
|
numKeys = serverStore->getNumKeys();
|
|
EXPECT_EQ(numKeys, 4);
|
|
auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
|
|
serverStore->setTimeout(timeout);
|
|
EXPECT_THROW(serverStore->get("key0"), c10::Error);
|
|
});
|
|
|
|
// Hammer on TCPStore
|
|
std::vector<std::thread> threads;
|
|
constexpr auto numIterations = 1000;
|
|
c10d::test::Semaphore sem1, sem2;
|
|
|
|
c10d::TCPStoreOptions opts{};
|
|
opts.port = serverTCPStore->getPort();
|
|
opts.numWorkers = numWorkers;
|
|
|
|
// Each thread will have a client store to send/recv data
|
|
std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
|
|
std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
clientTCPStores.push_back(
|
|
c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
|
|
clientStores.push_back(
|
|
c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
|
|
}
|
|
|
|
std::string expectedCounterRes =
|
|
std::to_string(numThreads * numIterations + 1);
|
|
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
threads.emplace_back(
|
|
std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
|
|
for (C10_UNUSED const auto j : c10::irange(numIterations)) {
|
|
clientStores[i]->add("counter", 1);
|
|
}
|
|
// Let each thread set and get key on its client store
|
|
std::string key = "thread_" + std::to_string(i);
|
|
for (const auto j : c10::irange(numIterations)) {
|
|
std::string val = "thread_val_" + std::to_string(j);
|
|
c10d::test::set(*clientStores[i], key, val);
|
|
c10d::test::check(*clientStores[i], key, val);
|
|
}
|
|
|
|
sem1.post();
|
|
sem2.wait();
|
|
// Check the counter results
|
|
c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
|
|
// Now check other threads' written data
|
|
for (const auto j : c10::irange(numThreads)) {
|
|
if (j == i) {
|
|
continue;
|
|
}
|
|
std::string key = "thread_" + std::to_string(i);
|
|
std::string val = "thread_val_" + std::to_string(numIterations - 1);
|
|
c10d::test::check(*clientStores[i], key, val);
|
|
}
|
|
}));
|
|
}
|
|
|
|
sem1.wait(numThreads);
|
|
sem2.post(numThreads);
|
|
|
|
for (auto& thread : threads) {
|
|
thread.join();
|
|
}
|
|
|
|
serverThread.join();
|
|
|
|
// Clear the store to test that client disconnect won't shutdown the store
|
|
clientStores.clear();
|
|
clientTCPStores.clear();
|
|
|
|
// Check that the counter has the expected value
|
|
c10d::test::check(*serverStore, "counter", expectedCounterRes);
|
|
|
|
// Check that each threads' written data from the main thread
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
std::string key = "thread_" + std::to_string(i);
|
|
std::string val = "thread_val_" + std::to_string(numIterations - 1);
|
|
c10d::test::check(*serverStore, key, val);
|
|
}
|
|
}
|
|
|
|
// Runs the shared TCPStore scenario with an empty key prefix.
TEST(TCPStoreTest, testHelper) {
  testHelper(/* prefix */ "");
}
|
|
|
|
// Runs the shared TCPStore scenario with the key prefix "testPrefix".
TEST(TCPStoreTest, testHelperPrefix) {
  const std::string keyPrefix = "testPrefix";
  testHelper(keyPrefix);
}
|
|
|
|
// Destroys the server store while a client thread is issuing get() on a key
// that was never set, and expects that get() to throw std::system_error.
TEST(TCPStoreTest, testCleanShutdown) {
  constexpr int kNumWorkers = 2;
  const auto storeTimeout = std::chrono::seconds(defaultTimeout);

  // Server side: positional TCPStore constructor, listening on port 0.
  auto server = std::make_unique<c10d::TCPStore>(
      "127.0.0.1",
      0,
      kNumWorkers,
      true,
      storeTimeout,
      /* wait */ false);
  c10d::test::set(*server, "key", "val");

  // Client side: connect to the port the server reports.
  const c10d::TCPStoreOptions clientOpts{
      /* port */ server->getPort(),
      /* isServer */ false,
      kNumWorkers,
      /* waitWorkers */ false,
      /* timeout */ storeTimeout};
  auto client = c10::make_intrusive<c10d::TCPStore>("127.0.0.1", clientOpts);
  // Read the existing key once on this thread before starting the
  // shutdown scenario.
  client->get("key");

  std::thread pendingGet([&client] {
    EXPECT_THROW(client->get("invalid_key"), std::system_error);
  });

  // start server shutdown during a client request
  server.reset();

  pendingGet.join();
}
|
|
|
|
// Creates two multi-tenant server stores from identical options, checks that
// a value written through one is readable through the other, then releases
// the second store and checks the first one still works.
TEST(TCPStoreTest, testMultiTenantStores) {
  c10d::TCPStoreOptions serverOpts{};
  serverOpts.isServer = true;
  serverOpts.multiTenant = true;

  const std::string key = "key0";
  const std::string value = "value0";

  // Construct two server stores on the same port.
  auto firstStore = c10::make_intrusive<c10d::TCPStore>("localhost", serverOpts);
  auto secondStore =
      c10::make_intrusive<c10d::TCPStore>("localhost", serverOpts);

  // Assert that the two stores share the same server.
  c10d::test::set(*firstStore, key, value);
  c10d::test::check(*secondStore, key, value);

  // Dispose the second instance and assert that the server is still alive.
  secondStore.reset();

  c10d::test::set(*firstStore, key, value);
  c10d::test::check(*firstStore, key, value);
}
|