diff --git a/build_variables.bzl b/build_variables.bzl index e1207a43ee4..4c45717cb3f 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -528,6 +528,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", + "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/Utils.cpp", "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/debug.cpp", diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index a4554d73d7e..c2cfebd0d7f 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -29,114 +30,6 @@ namespace c10d { namespace detail { namespace { - -// Abstract base class to handle thread state for TCPStoreMasterDaemon. -// Contains the windows/unix implementations to signal a -// shutdown sequence for the thread -class BackgroundThread { - public: - explicit BackgroundThread(Socket&& storeListenSocket); - - virtual ~BackgroundThread() = 0; - - protected: - void dispose(); - - Socket storeListenSocket_; - std::thread daemonThread_{}; - std::vector sockets_{}; -#ifdef _WIN32 - const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10}; - HANDLE ghStopEvent_{}; -#else - std::array controlPipeFd_{{-1, -1}}; -#endif - - private: - // Initialization for shutdown signal - void initStopSignal(); - // Triggers the shutdown signal - void stop(); - // Joins the thread - void join(); - // Clean up the shutdown signal - void closeStopSignal(); -}; - -// Background thread parent class methods -BackgroundThread::BackgroundThread(Socket&& storeListenSocket) - : storeListenSocket_{std::move(storeListenSocket)} { - // Signal instance destruction to the daemon thread. - initStopSignal(); -} - -BackgroundThread::~BackgroundThread() = default; - -// WARNING: -// Since we rely on the subclass for the daemon thread clean-up, we cannot -// destruct our member variables in the destructor. The subclass must call -// dispose() in its own destructor. -void BackgroundThread::dispose() { - // Stop the run - stop(); - // Join the thread - join(); - // Close unclosed sockets - sockets_.clear(); - // Now close the rest control pipe - closeStopSignal(); -} - -void BackgroundThread::join() { - daemonThread_.join(); -} - -#ifdef _WIN32 -void BackgroundThread::initStopSignal() { - ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); - if (ghStopEvent_ == NULL) { - TORCH_CHECK( - false, - "Failed to create the control pipe to start the " - "BackgroundThread run"); - } -} - -void BackgroundThread::closeStopSignal() { - CloseHandle(ghStopEvent_); -} - -void BackgroundThread::stop() { - SetEvent(ghStopEvent_); -} -#else -void BackgroundThread::initStopSignal() { - if (pipe(controlPipeFd_.data()) == -1) { - TORCH_CHECK( - false, - "Failed to create the control pipe to start the " - "BackgroundThread run"); - } -} - -void BackgroundThread::closeStopSignal() { - for (int fd : controlPipeFd_) { - if (fd != -1) { - ::close(fd); - } - } -} - -void BackgroundThread::stop() { - if (controlPipeFd_[1] != -1) { - ::write(controlPipeFd_[1], "\0", 1); - // close the write end of the pipe - ::close(controlPipeFd_[1]); - controlPipeFd_[1] = -1; - } -} -#endif - enum class QueryType : uint8_t { SET, COMPARE_SET, diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp new file mode 100644 index 00000000000..47e28b26c87 --- /dev/null +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp @@ -0,0 +1,105 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#else +#include +#include +#endif + +#ifdef _WIN32 +#include +#else +#include +#endif + +#include + +namespace c10d { +namespace detail { + +// Background thread parent class methods +BackgroundThread::BackgroundThread(Socket&& storeListenSocket) + : storeListenSocket_{std::move(storeListenSocket)} { + // Signal instance destruction to the daemon thread. + initStopSignal(); +} + +BackgroundThread::~BackgroundThread() = default; + +// WARNING: +// Since we rely on the subclass for the daemon thread clean-up, we cannot +// destruct our member variables in the destructor. The subclass must call +// dispose() in its own destructor. +void BackgroundThread::dispose() { + // Stop the run + stop(); + // Join the thread + join(); + // Close unclosed sockets + sockets_.clear(); + // Now close the rest control pipe + closeStopSignal(); +} + +void BackgroundThread::join() { + daemonThread_.join(); +} + +#ifdef _WIN32 +void BackgroundThread::initStopSignal() { + ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); + if (ghStopEvent_ == NULL) { + TORCH_CHECK( + false, + "Failed to create the control pipe to start the " + "BackgroundThread run"); + } +} + +void BackgroundThread::closeStopSignal() { + CloseHandle(ghStopEvent_); +} + +void BackgroundThread::stop() { + SetEvent(ghStopEvent_); +} +#else +void BackgroundThread::initStopSignal() { + if (pipe(controlPipeFd_.data()) == -1) { + TORCH_CHECK( + false, + "Failed to create the control pipe to start the " + "BackgroundThread run"); + } +} + +void BackgroundThread::closeStopSignal() { + for (int fd : controlPipeFd_) { + if (fd != -1) { + ::close(fd); + } + } +} + +void BackgroundThread::stop() { + if (controlPipeFd_[1] != -1) { + ::write(controlPipeFd_[1], "\0", 1); + // close the write end of the pipe + ::close(controlPipeFd_[1]); + controlPipeFd_[1] = -1; + } +} +#endif + +} // namespace detail +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.hpp b/torch/csrc/distributed/c10d/TCPStoreBackend.hpp new file mode 100644 index 00000000000..71b952e1f5e --- /dev/null +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include + +#include + +#ifdef _WIN32 +#include +#include +#else +#include +#include +#endif + +namespace c10d { +namespace detail { + +// Abstract base class to handle thread state for TCPStoreMasterDaemon. +// Contains the windows/unix implementations to signal a +// shutdown sequence for the thread +class BackgroundThread { + public: + explicit BackgroundThread(Socket&& storeListenSocket); + + virtual ~BackgroundThread() = 0; + + protected: + void dispose(); + + Socket storeListenSocket_; + std::thread daemonThread_{}; + std::vector sockets_{}; +#ifdef _WIN32 + const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10}; + HANDLE ghStopEvent_{}; +#else + std::array controlPipeFd_{{-1, -1}}; +#endif + + private: + // Initialization for shutdown signal + void initStopSignal(); + // Triggers the shutdown signal + void stop(); + // Joins the thread + void join(); + // Clean up the shutdown signal + void closeStopSignal(); +}; + +} // namespace detail +} // namespace c10d