mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revise the socket implementation of c10d (#68226)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68226 **Note that this PR is unusually big due to the urgency of the changes. Please reach out to me in case you wish to have a "pair" review.** This PR introduces a major refactoring of the socket implementation of the C10d library. A big portion of the logic is now contained in the `Socket` class and a follow-up PR will further consolidate the remaining parts. As of today the changes in this PR offer: - significantly better error handling and much more verbose logging (see the example output below) - explicit support for IPv6 and dual-stack sockets - correct handling of signal interrupts - better Windows support A follow-up PR will consolidate `send`/`recv` logic into `Socket` and fully migrate to non-blocking sockets. ## Example Output ``` [I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501). [I logging.h:21] The client socket is attempting to connect to [localhost]:29501. [W logging.h:28] The server socket on [localhost]:29501 is not yet listening (Error: 111 - Connection refused), retrying... [I logging.h:21] The server socket will attempt to listen on an IPv6 address. [I logging.h:21] The server socket is attempting to listen on [::]:29501. [I logging.h:21] The server socket has started to listen on [::]:29501. [I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501). [I logging.h:21] The client socket is attempting to connect to [localhost]:29501. [I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42650. [I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42650. [I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42722. [I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42722. 
[I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501). [I logging.h:21] The client socket is attempting to connect to [localhost]:29501. [I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42724. [I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42724. [I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501). [I logging.h:21] The client socket is attempting to connect to [localhost]:29501. [I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42726. [I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42726. ``` ghstack-source-id: 143501987 Test Plan: Run existing unit and integration tests on devserver, Fedora, Ubuntu, macOS Big Sur, Windows 10. Reviewed By: Babar, wilson100hong, mrshenli Differential Revision: D32372333 fbshipit-source-id: 2204ffa28ed0d3683a9cb3ebe1ea8d92a831325a
This commit is contained in:
parent
4c346bd073
commit
6e640a0acf
1
setup.py
1
setup.py
|
|
@ -1005,6 +1005,7 @@ if __name__ == '__main__':
|
|||
'include/torch/csrc/autograd/generated/*.h',
|
||||
'include/torch/csrc/autograd/utils/*.h',
|
||||
'include/torch/csrc/cuda/*.h',
|
||||
'include/torch/csrc/distributed/c10d/exception.h',
|
||||
'include/torch/csrc/jit/*.h',
|
||||
'include/torch/csrc/jit/backends/*.h',
|
||||
'include/torch/csrc/jit/generated/*.h',
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class DistributedUtilTest(TestCase):
|
|||
server_port=pick_free_port,
|
||||
timeout=1,
|
||||
)
|
||||
with self.assertRaises(IOError):
|
||||
with self.assertRaises(RuntimeError):
|
||||
create_c10d_store(
|
||||
is_server=True, server_addr=server_addr, server_port=store1.port
|
||||
)
|
||||
|
|
@ -142,7 +142,7 @@ class DistributedUtilTest(TestCase):
|
|||
port = sock.getsockname()[1]
|
||||
# on the worker port conflict shouldn't matter, it should just timeout
|
||||
# since we never created a server
|
||||
with self.assertRaises(IOError):
|
||||
with self.assertRaises(TimeoutError):
|
||||
create_c10d_store(
|
||||
is_server=False,
|
||||
server_addr=socket.gethostname(),
|
||||
|
|
|
|||
|
|
@ -157,10 +157,7 @@ class TCPStoreTest(TestCase, StoreTestBase):
|
|||
return store
|
||||
|
||||
def test_address_already_in_use(self):
|
||||
if sys.platform == "win32":
|
||||
err_msg_reg = "Only one usage of each socket address*"
|
||||
else:
|
||||
err_msg_reg = "^Address already in use$"
|
||||
err_msg_reg = "^The server socket has failed to listen on any local "
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg_reg):
|
||||
addr = DEFAULT_HOSTNAME
|
||||
port = common.find_free_port()
|
||||
|
|
|
|||
|
|
@ -384,24 +384,26 @@ libtorch_core_sources = sorted(
|
|||
|
||||
# These files are the only ones that are supported on Windows.
|
||||
libtorch_distributed_base_sources = [
|
||||
"torch/csrc/distributed/c10d/frontend.cpp",
|
||||
"torch/csrc/distributed/c10d/comm.cpp",
|
||||
"torch/csrc/distributed/c10d/default_comm_hooks.cpp",
|
||||
"torch/csrc/distributed/c10d/FileStore.cpp",
|
||||
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
|
||||
"torch/csrc/distributed/c10d/logger.cpp",
|
||||
"torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
|
||||
"torch/csrc/distributed/c10d/PrefixStore.cpp",
|
||||
"torch/csrc/distributed/c10d/ProcessGroup.cpp",
|
||||
"torch/csrc/distributed/c10d/ProcessGroupGloo.cpp",
|
||||
"torch/csrc/distributed/c10d/ProcessGroupMPI.cpp",
|
||||
"torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization.cpp",
|
||||
"torch/csrc/distributed/c10d/reducer.cpp",
|
||||
"torch/csrc/distributed/c10d/sequence_num.cpp",
|
||||
"torch/csrc/distributed/c10d/Store.cpp",
|
||||
"torch/csrc/distributed/c10d/TCPStore.cpp",
|
||||
"torch/csrc/distributed/c10d/Utils.cpp",
|
||||
"torch/csrc/distributed/c10d/comm.cpp",
|
||||
"torch/csrc/distributed/c10d/default_comm_hooks.cpp",
|
||||
"torch/csrc/distributed/c10d/exception.cpp",
|
||||
"torch/csrc/distributed/c10d/frontend.cpp",
|
||||
"torch/csrc/distributed/c10d/logger.cpp",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization.cpp",
|
||||
"torch/csrc/distributed/c10d/reducer.cpp",
|
||||
"torch/csrc/distributed/c10d/sequence_num.cpp",
|
||||
"torch/csrc/distributed/c10d/socket.cpp",
|
||||
]
|
||||
|
||||
# These files are only supported on Linux (and others) but not on Windows.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <mutex>
|
||||
|
|
@ -15,6 +16,10 @@
|
|||
#include <c10/util/StringUtil.h>
|
||||
#include <ATen/detail/FunctionTraits.h>
|
||||
|
||||
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
|
||||
#include <torch/csrc/distributed/c10d/exception.h>
|
||||
#endif
|
||||
|
||||
static inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
||||
PyErr_SetString(type, message.c_str());
|
||||
}
|
||||
|
|
@ -48,7 +53,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
|||
try {
|
||||
|
||||
// Only catch torch-specific exceptions
|
||||
#define CATCH_TH_ERRORS(retstmnt) \
|
||||
#define CATCH_CORE_ERRORS(retstmnt) \
|
||||
catch (python_error & e) { \
|
||||
e.restore(); \
|
||||
retstmnt; \
|
||||
|
|
@ -83,12 +88,27 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
|||
PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(msg)); \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (torch::PyTorchError & e) { \
|
||||
catch (torch::PyTorchError& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(e.python_type(), msg); \
|
||||
retstmnt; \
|
||||
}
|
||||
|
||||
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
|
||||
#define CATCH_C10D_ERRORS(retstmnt) \
|
||||
catch (const c10d::TimeoutError& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(PyExc_TimeoutError, msg); \
|
||||
retstmnt; \
|
||||
}
|
||||
#else
|
||||
#define CATCH_C10D_ERRORS(retstmnt)
|
||||
#endif
|
||||
|
||||
#define CATCH_TH_ERRORS(retstmnt) \
|
||||
CATCH_CORE_ERRORS(retstmnt) \
|
||||
CATCH_C10D_ERRORS(retstmnt)
|
||||
|
||||
#define CATCH_ALL_ERRORS(retstmnt) \
|
||||
CATCH_TH_ERRORS(retstmnt) \
|
||||
catch (const std::exception& e) { \
|
||||
|
|
|
|||
|
|
@ -23,66 +23,27 @@
|
|||
#include <c10d/UnixSockUtils.hpp>
|
||||
#endif
|
||||
|
||||
#include <c10d/socket.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace detail {
|
||||
namespace {
|
||||
|
||||
// Offers RAII for TCP sockets.
|
||||
class TCPSocket {
|
||||
public:
|
||||
TCPSocket() noexcept = default;
|
||||
|
||||
/* implicit */ TCPSocket(int handle) noexcept : handle_{handle} {}
|
||||
|
||||
TCPSocket(const TCPSocket& other) = delete;
|
||||
|
||||
TCPSocket& operator=(const TCPSocket& other) = delete;
|
||||
|
||||
TCPSocket(TCPSocket&& other) noexcept : handle_{other.handle_} {
|
||||
other.handle_ = c10::nullopt;
|
||||
}
|
||||
|
||||
TCPSocket& operator=(TCPSocket&& other) noexcept {
|
||||
closeSocket();
|
||||
|
||||
handle_ = std::exchange(other.handle_, c10::nullopt);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
~TCPSocket() {
|
||||
closeSocket();
|
||||
}
|
||||
|
||||
int handle() const noexcept {
|
||||
return handle_.value_or(-1);
|
||||
}
|
||||
|
||||
private:
|
||||
void closeSocket() noexcept {
|
||||
if (handle_) {
|
||||
tcputil::closeSocket(*handle_);
|
||||
}
|
||||
}
|
||||
|
||||
c10::optional<int> handle_{};
|
||||
};
|
||||
|
||||
// Abstract base class to handle thread state for TCPStoreMasterDaemon and
|
||||
// TCPStoreWorkerDaemon. Contains the windows/unix implementations to signal a
|
||||
// shutdown sequence for the thread
|
||||
class BackgroundThread {
|
||||
public:
|
||||
explicit BackgroundThread(TCPSocket&& storeListenSocket);
|
||||
explicit BackgroundThread(Socket&& storeListenSocket);
|
||||
|
||||
virtual ~BackgroundThread() = 0;
|
||||
|
||||
protected:
|
||||
void dispose();
|
||||
|
||||
TCPSocket storeListenSocket_;
|
||||
Socket storeListenSocket_;
|
||||
std::thread daemonThread_{};
|
||||
std::vector<TCPSocket> sockets_{};
|
||||
std::vector<Socket> sockets_{};
|
||||
#ifdef _WIN32
|
||||
const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
|
||||
HANDLE ghStopEvent_{};
|
||||
|
|
@ -102,7 +63,7 @@ class BackgroundThread {
|
|||
};
|
||||
|
||||
// Background thread parent class methods
|
||||
BackgroundThread::BackgroundThread(TCPSocket&& storeListenSocket)
|
||||
BackgroundThread::BackgroundThread(Socket&& storeListenSocket)
|
||||
: storeListenSocket_{std::move(storeListenSocket)} {
|
||||
// Signal instance destruction to the daemon thread.
|
||||
initStopSignal();
|
||||
|
|
@ -200,7 +161,7 @@ enum class WatchResponseType : uint8_t {
|
|||
// Separate thread that is only launched on master
|
||||
class TCPStoreMasterDaemon : public BackgroundThread {
|
||||
public:
|
||||
explicit TCPStoreMasterDaemon(TCPSocket&& storeListenSocket);
|
||||
explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);
|
||||
|
||||
~TCPStoreMasterDaemon() override;
|
||||
|
||||
|
|
@ -241,7 +202,7 @@ class TCPStoreMasterDaemon : public BackgroundThread {
|
|||
};
|
||||
|
||||
// Simply start the daemon thread
|
||||
TCPStoreMasterDaemon::TCPStoreMasterDaemon(TCPSocket&& storeListenSocket)
|
||||
TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
|
||||
: BackgroundThread{std::move(storeListenSocket)} {
|
||||
daemonThread_ = std::thread{&TCPStoreMasterDaemon::run, this};
|
||||
}
|
||||
|
|
@ -270,7 +231,6 @@ void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
|
|||
// exception, other connections will get an exception once they try to
|
||||
// use the store. We will go ahead and close this connection whenever
|
||||
// we hit an exception here.
|
||||
tcputil::closeSocket(fds[fdIdx].fd);
|
||||
|
||||
// Remove all the tracking state of the close FD
|
||||
for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
|
||||
|
|
@ -562,8 +522,7 @@ void TCPStoreMasterDaemon::run() {
|
|||
"Unexpected poll revent on the master's listening socket: " +
|
||||
std::to_string(fds[0].revents));
|
||||
}
|
||||
TCPSocket socket =
|
||||
std::get<0>(tcputil::accept(storeListenSocket_.handle()));
|
||||
Socket socket = storeListenSocket_.accept();
|
||||
int rawSocket = socket.handle();
|
||||
sockets_.emplace_back(std::move(socket));
|
||||
tcputil::addPollfd(fds, rawSocket, POLLIN);
|
||||
|
|
@ -597,8 +556,7 @@ void TCPStoreMasterDaemon::run() {
|
|||
"Unexpected poll revent on the master's listening socket: " +
|
||||
std::to_string(fds[0].revents));
|
||||
}
|
||||
TCPSocket socket =
|
||||
std::get<0>(tcputil::accept(storeListenSocket_.handle()));
|
||||
Socket socket = storeListenSocket_.accept();
|
||||
int rawSocket = socket.handle();
|
||||
sockets_.emplace_back(std::move(socket));
|
||||
tcputil::addPollfd(fds, rawSocket, POLLIN);
|
||||
|
|
@ -626,7 +584,7 @@ void TCPStoreMasterDaemon::run() {
|
|||
// Right now only handles callbacks registered from watchKey()
|
||||
class TCPStoreWorkerDaemon : public BackgroundThread {
|
||||
public:
|
||||
explicit TCPStoreWorkerDaemon(TCPSocket&& listenSocket);
|
||||
explicit TCPStoreWorkerDaemon(Socket&& listenSocket);
|
||||
~TCPStoreWorkerDaemon() override;
|
||||
// Set the callback to run key change
|
||||
void setCallback(std::string key, WatchKeyCallback cb);
|
||||
|
|
@ -657,7 +615,7 @@ class TCPStoreWorkerDaemon : public BackgroundThread {
|
|||
};
|
||||
|
||||
// TCPStoreWorkerDaemon class methods
|
||||
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(TCPSocket&& listenSocket)
|
||||
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket&& listenSocket)
|
||||
: BackgroundThread{std::move(listenSocket)} {
|
||||
daemonThread_ = std::thread{&TCPStoreWorkerDaemon::run, this};
|
||||
}
|
||||
|
|
@ -806,10 +764,9 @@ std::mutex TCPServer::cache_mutex_{};
|
|||
|
||||
std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
|
||||
auto startCore = [&opts]() {
|
||||
TCPSocket socket{};
|
||||
std::uint16_t port{};
|
||||
Socket socket = Socket::listen(opts.port);
|
||||
|
||||
std::tie(socket, port) = tcputil::listen(opts.port);
|
||||
std::uint16_t port = socket.port();
|
||||
|
||||
auto daemon = std::make_unique<TCPStoreMasterDaemon>(std::move(socket));
|
||||
|
||||
|
|
@ -881,17 +838,19 @@ class TCPClient {
|
|||
|
||||
void setTimeout(std::chrono::milliseconds value);
|
||||
|
||||
explicit TCPClient(TCPSocket&& socket) : socket_{std::move(socket)} {}
|
||||
explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}
|
||||
|
||||
private:
|
||||
TCPSocket socket_;
|
||||
Socket socket_;
|
||||
};
|
||||
|
||||
std::unique_ptr<TCPClient> TCPClient::connect(
|
||||
const SocketAddress& addr,
|
||||
const TCPStoreOptions& opts) {
|
||||
TCPSocket socket =
|
||||
tcputil::connect(addr.host, addr.port, /* wait */ true, opts.timeout);
|
||||
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
|
||||
Socket socket = Socket::connect(addr.host,
|
||||
addr.port,
|
||||
SocketOptions{}.connect_timeout(timeout));
|
||||
|
||||
return std::make_unique<TCPClient>(std::move(socket));
|
||||
}
|
||||
|
|
@ -962,8 +921,10 @@ class TCPCallbackClient {
|
|||
std::unique_ptr<TCPCallbackClient> TCPCallbackClient::connect(
|
||||
const SocketAddress& addr,
|
||||
const TCPStoreOptions& opts) {
|
||||
TCPSocket socket =
|
||||
tcputil::connect(addr.host, addr.port, /*wait*/ true, opts.timeout);
|
||||
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
|
||||
Socket socket = Socket::connect(addr.host,
|
||||
addr.port,
|
||||
SocketOptions{}.connect_timeout(timeout));
|
||||
|
||||
int rawSocket = socket.handle();
|
||||
|
||||
|
|
@ -988,6 +949,8 @@ void TCPCallbackClient::setCallback(
|
|||
|
||||
} // namespace detail
|
||||
|
||||
using detail::Socket;
|
||||
|
||||
// TCPStore class methods
|
||||
TCPStore::TCPStore(
|
||||
const std::string& masterAddr,
|
||||
|
|
@ -1010,7 +973,7 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
|
|||
: Store{opts.timeout},
|
||||
addr_{std::move(host)},
|
||||
numWorkers_{opts.numWorkers} {
|
||||
tcputil::socketInitialize();
|
||||
Socket::initialize();
|
||||
|
||||
if (opts.isServer) {
|
||||
server_ = detail::TCPServer::start(opts);
|
||||
|
|
|
|||
|
|
@ -5,16 +5,8 @@
|
|||
namespace c10d {
|
||||
namespace tcputil {
|
||||
|
||||
#define AF_SELECTED AF_UNSPEC
|
||||
#define CONNECT_SOCKET_OFFSET 2
|
||||
|
||||
inline void closeSocket(int socket) { ::close(socket); }
|
||||
|
||||
inline int setSocketAddrReUse(int socket) {
|
||||
int optval = 1;
|
||||
return ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int));
|
||||
}
|
||||
|
||||
inline int poll(struct pollfd *fds, unsigned long nfds, int timeout) {
|
||||
return ::poll(fds, nfds, timeout);
|
||||
}
|
||||
|
|
@ -24,70 +16,6 @@ inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
|
|||
fds.push_back({.fd = socket, .events = events});
|
||||
}
|
||||
|
||||
inline void waitSocketConnected(
|
||||
int socket,
|
||||
struct ::addrinfo *nextAddr,
|
||||
std::chrono::milliseconds timeout,
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> startTime) {
|
||||
SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK));
|
||||
|
||||
int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
|
||||
|
||||
if (ret != 0 && errno != EINPROGRESS) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
}
|
||||
|
||||
struct ::pollfd pfd;
|
||||
pfd.fd = socket;
|
||||
pfd.events = POLLOUT;
|
||||
|
||||
int64_t pollTimeout = -1;
|
||||
if (timeout != kNoTimeout) {
|
||||
// calculate remaining time and use that as timeout for poll()
|
||||
const auto elapsed = std::chrono::high_resolution_clock::now() - startTime;
|
||||
const auto remaining =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(timeout) -
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(elapsed);
|
||||
pollTimeout = std::max(static_cast<int64_t>(0),
|
||||
static_cast<int64_t>(remaining.count()));
|
||||
}
|
||||
int numReady = ::poll(&pfd, 1, pollTimeout);
|
||||
if (numReady < 0) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
} else if (numReady == 0) {
|
||||
errno = 0;
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
c10::str(
|
||||
kConnectTimeoutMsg,
|
||||
" Polled for ",
|
||||
pollTimeout,
|
||||
" ms with original timeout of ",
|
||||
timeout.count(),
|
||||
" ms."));
|
||||
}
|
||||
|
||||
socklen_t errLen = sizeof(errno);
|
||||
errno = 0;
|
||||
::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen);
|
||||
|
||||
// `errno` is set when:
|
||||
// 1. `getsockopt` has failed
|
||||
// 2. there is a pending error on the socket
|
||||
// (the error is saved to the `errno` variable)
|
||||
if (errno != 0) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
}
|
||||
|
||||
// Disable non-blocking mode
|
||||
int flags;
|
||||
SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL));
|
||||
SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
|
||||
}
|
||||
|
||||
// Linux socket does not need init libs first
|
||||
inline void socketInitialize() {}
|
||||
|
||||
inline struct ::pollfd getPollfd(int socket, short events) {
|
||||
struct ::pollfd res = {.fd = socket, .events = events};
|
||||
return res;
|
||||
|
|
|
|||
|
|
@ -1,16 +1,5 @@
|
|||
#ifdef _WIN32
|
||||
#include <c10d/WinSockUtils.hpp>
|
||||
#else
|
||||
#include <arpa/inet.h>
|
||||
#include <c10d/UnixSockUtils.hpp>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/poll.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include <c10d/Utils.hpp>
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
|
@ -97,298 +86,4 @@ std::vector<at::Tensor> getTensorShapes(
|
|||
return shapeTensors;
|
||||
}
|
||||
|
||||
namespace tcputil {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr int LISTEN_QUEUE_SIZE = 2048;
|
||||
|
||||
void setSocketNoDelay(int socket) {
|
||||
int flag = 1;
|
||||
socklen_t optlen = sizeof(flag);
|
||||
SYSCHECK_ERR_RETURN_NEG1(
|
||||
setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
|
||||
}
|
||||
|
||||
PortType getSocketPort(int fd) {
|
||||
PortType listenPort;
|
||||
struct ::sockaddr_storage addrStorage;
|
||||
socklen_t addrLen = sizeof(addrStorage);
|
||||
SYSCHECK_ERR_RETURN_NEG1(getsockname(
|
||||
fd, reinterpret_cast<struct ::sockaddr*>(&addrStorage), &addrLen));
|
||||
|
||||
if (addrStorage.ss_family == AF_INET) {
|
||||
struct ::sockaddr_in* addr =
|
||||
reinterpret_cast<struct ::sockaddr_in*>(&addrStorage);
|
||||
listenPort = ntohs(addr->sin_port);
|
||||
|
||||
} else if (addrStorage.ss_family == AF_INET6) { // AF_INET6
|
||||
struct ::sockaddr_in6* addr =
|
||||
reinterpret_cast<struct ::sockaddr_in6*>(&addrStorage);
|
||||
listenPort = ntohs(addr->sin6_port);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported protocol");
|
||||
}
|
||||
return listenPort;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string sockaddrToString(struct ::sockaddr* addr) {
|
||||
char address[INET6_ADDRSTRLEN + 1];
|
||||
if (addr->sa_family == AF_INET) {
|
||||
struct ::sockaddr_in* s = reinterpret_cast<struct ::sockaddr_in*>(addr);
|
||||
SYSCHECK(
|
||||
::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN),
|
||||
__output != nullptr)
|
||||
address[INET_ADDRSTRLEN] = '\0';
|
||||
} else if (addr->sa_family == AF_INET6) {
|
||||
struct ::sockaddr_in6* s = reinterpret_cast<struct ::sockaddr_in6*>(addr);
|
||||
SYSCHECK(
|
||||
::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN),
|
||||
__output != nullptr)
|
||||
address[INET6_ADDRSTRLEN] = '\0';
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported protocol");
|
||||
}
|
||||
return address;
|
||||
}
|
||||
|
||||
// listen, connect and accept
|
||||
std::pair<int, PortType> listen(PortType port) {
|
||||
struct ::addrinfo hints, *res = NULL;
|
||||
std::memset(&hints, 0x00, sizeof(hints));
|
||||
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
|
||||
hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP
|
||||
|
||||
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
|
||||
// by editing `/etc/gai.conf`, so there is no need for manual sorting
|
||||
// or protocol preference.
|
||||
int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res);
|
||||
if (err != 0 || !res) {
|
||||
throw std::invalid_argument(
|
||||
"cannot find host to listen on: " + std::string(gai_strerror(err)));
|
||||
}
|
||||
|
||||
std::shared_ptr<struct ::addrinfo> addresses(
|
||||
res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); });
|
||||
|
||||
struct ::addrinfo* nextAddr = addresses.get();
|
||||
int socket;
|
||||
while (true) {
|
||||
try {
|
||||
SYSCHECK_ERR_RETURN_NEG1(
|
||||
socket = ::socket(
|
||||
nextAddr->ai_family,
|
||||
nextAddr->ai_socktype,
|
||||
nextAddr->ai_protocol))
|
||||
SYSCHECK_ERR_RETURN_NEG1(tcputil::setSocketAddrReUse(socket))
|
||||
SYSCHECK_ERR_RETURN_NEG1(
|
||||
::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen))
|
||||
SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE))
|
||||
break;
|
||||
|
||||
} catch (const std::system_error& e) {
|
||||
tcputil::closeSocket(socket);
|
||||
nextAddr = nextAddr->ai_next;
|
||||
|
||||
// we have tried all addresses but could not start
|
||||
// listening on any of them
|
||||
if (!nextAddr) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get listen port and address
|
||||
return {socket, getSocketPort(socket)};
|
||||
}
|
||||
|
||||
void handleConnectException(
|
||||
struct ::addrinfo** nextAddr,
|
||||
int error_code,
|
||||
bool* anyRefused,
|
||||
bool* anyReset,
|
||||
bool wait,
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> start,
|
||||
std::shared_ptr<struct ::addrinfo> addresses,
|
||||
std::chrono::milliseconds timeout) {
|
||||
// ECONNREFUSED happens if the server is not yet listening.
|
||||
if (error_code == ECONNREFUSED) {
|
||||
*anyRefused = true;
|
||||
}
|
||||
// ECONNRESET happens if the server's listen backlog is exhausted.
|
||||
if (error_code == ECONNRESET) {
|
||||
*anyReset = true;
|
||||
}
|
||||
|
||||
// We need to move to the next address because this was not available
|
||||
// to connect or to create a socket.
|
||||
*nextAddr = (*nextAddr)->ai_next;
|
||||
|
||||
// We have tried all addresses but could not connect to any of them.
|
||||
if (!*nextAddr) {
|
||||
if (!wait || (!anyRefused && !anyReset)) {
|
||||
throw;
|
||||
}
|
||||
|
||||
// if a timeout is specified, check time elapsed to see if we need to
|
||||
// timeout. A timeout is specified if timeout != kNoTimeout.
|
||||
if (timeout != kNoTimeout) {
|
||||
const auto elapsed = std::chrono::high_resolution_clock::now() - start;
|
||||
TORCH_CHECK(
|
||||
elapsed <= timeout,
|
||||
c10::str(
|
||||
kConnectTimeoutMsg,
|
||||
" Original timeout was ",
|
||||
timeout.count(),
|
||||
" ms."));
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
*anyRefused = false;
|
||||
*anyReset = false;
|
||||
*nextAddr = addresses.get();
|
||||
}
|
||||
}
|
||||
|
||||
void handleConnectSystemError(
|
||||
struct ::addrinfo** nextAddr,
|
||||
std::system_error& e,
|
||||
bool* anyRefused,
|
||||
bool* anyReset,
|
||||
bool wait,
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> start,
|
||||
std::shared_ptr<struct ::addrinfo> addresses,
|
||||
std::chrono::milliseconds timeout) {
|
||||
handleConnectException(
|
||||
nextAddr,
|
||||
e.code().value(),
|
||||
anyRefused,
|
||||
anyReset,
|
||||
wait,
|
||||
start,
|
||||
addresses,
|
||||
timeout);
|
||||
}
|
||||
|
||||
int connect(
|
||||
const std::string& address,
|
||||
PortType port,
|
||||
bool wait,
|
||||
const std::chrono::milliseconds& timeout) {
|
||||
struct ::addrinfo hints, *res = NULL;
|
||||
std::memset(&hints, 0x00, sizeof(hints));
|
||||
hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
|
||||
hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP
|
||||
|
||||
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
|
||||
// by editing `/etc/gai.conf`, so there is no need for manual sorting
|
||||
// or protocol preference.
|
||||
int err =
|
||||
::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res);
|
||||
if (err != 0 || !res) {
|
||||
throw std::invalid_argument(
|
||||
"host not found: " + std::string(gai_strerror(err)));
|
||||
}
|
||||
|
||||
std::shared_ptr<struct ::addrinfo> addresses(
|
||||
res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); });
|
||||
|
||||
struct ::addrinfo* nextAddr = addresses.get();
|
||||
int socket;
|
||||
|
||||
// Loop over the addresses if at least one of them gave us ECONNREFUSED
|
||||
// or ECONNRESET. This may happen if the server hasn't started listening
|
||||
// yet, or is listening but has its listen backlog exhausted.
|
||||
bool anyRefused = false;
|
||||
bool anyReset = false;
|
||||
const auto start = std::chrono::high_resolution_clock::now();
|
||||
while (true) {
|
||||
try {
|
||||
SYSCHECK_ERR_RETURN_NEG1(
|
||||
socket = ::socket(
|
||||
nextAddr->ai_family,
|
||||
nextAddr->ai_socktype,
|
||||
nextAddr->ai_protocol))
|
||||
|
||||
ResourceGuard socketGuard([socket]() { tcputil::closeSocket(socket); });
|
||||
|
||||
// We need to connect in non-blocking mode, so we can use a timeout
|
||||
waitSocketConnected(socket, nextAddr, timeout, start);
|
||||
|
||||
socketGuard.release();
|
||||
break;
|
||||
|
||||
} catch (std::system_error& e) {
|
||||
handleConnectSystemError(
|
||||
&nextAddr,
|
||||
e,
|
||||
&anyRefused,
|
||||
&anyReset,
|
||||
wait,
|
||||
start,
|
||||
addresses,
|
||||
timeout);
|
||||
} catch (std::exception& e) {
|
||||
handleConnectException(
|
||||
&nextAddr,
|
||||
errno,
|
||||
&anyRefused,
|
||||
&anyReset,
|
||||
wait,
|
||||
start,
|
||||
addresses,
|
||||
timeout);
|
||||
}
|
||||
}
|
||||
|
||||
setSocketNoDelay(socket);
|
||||
|
||||
return socket;
|
||||
}
|
||||
|
||||
std::tuple<int, std::string> accept(
|
||||
int listenSocket,
|
||||
const std::chrono::milliseconds& timeout) {
|
||||
// Poll on the listen socket so that a timeout can be applied while waiting
|
||||
std::unique_ptr<struct ::pollfd[]> events(new struct ::pollfd[1]);
|
||||
events[0] = tcputil::getPollfd(listenSocket, POLLIN);
|
||||
|
||||
while (true) {
|
||||
int res = tcputil::poll(events.get(), 1, timeout.count());
|
||||
if (res == 0) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"waiting for processes to "
|
||||
"connect has timed out");
|
||||
} else if (res == -1) {
|
||||
if (errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
throw std::system_error(errno, std::system_category());
|
||||
} else {
|
||||
if (!(events[0].revents & POLLIN))
|
||||
throw std::system_error(ECONNABORTED, std::system_category());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int socket;
|
||||
SYSCHECK_ERR_RETURN_NEG1(socket = ::accept(listenSocket, NULL, NULL))
|
||||
|
||||
// Get address of the connecting process
|
||||
struct ::sockaddr_storage addr;
|
||||
socklen_t addrLen = sizeof(addr);
|
||||
SYSCHECK_ERR_RETURN_NEG1(::getpeername(
|
||||
socket, reinterpret_cast<struct ::sockaddr*>(&addr), &addrLen))
|
||||
|
||||
setSocketNoDelay(socket);
|
||||
|
||||
return std::make_tuple(
|
||||
socket, sockaddrToString(reinterpret_cast<struct ::sockaddr*>(&addr)));
|
||||
}
|
||||
} // namespace tcputil
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -482,7 +482,6 @@ size_t computeLengthsAndOffsets(
|
|||
}
|
||||
|
||||
using RankType = uint32_t;
|
||||
using PortType = uint16_t;
|
||||
using SizeType = uint64_t;
|
||||
|
||||
// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets
|
||||
|
|
@ -536,32 +535,8 @@ using SizeType = uint64_t;
|
|||
// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
|
||||
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
|
||||
|
||||
// Helper resource guard class
|
||||
class ResourceGuard {
|
||||
public:
|
||||
ResourceGuard(std::function<void()> destructor)
|
||||
: destructor_(std::move(destructor)), released_(false) {}
|
||||
|
||||
~ResourceGuard() {
|
||||
if (!released_) {
|
||||
destructor_();
|
||||
}
|
||||
}
|
||||
|
||||
void release() {
|
||||
released_ = true;
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<void()> destructor_;
|
||||
bool released_;
|
||||
};
|
||||
|
||||
namespace tcputil {
|
||||
|
||||
constexpr std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds(-1);
|
||||
const std::string kConnectTimeoutMsg = "connect() timed out.";
|
||||
|
||||
// Send and receive
|
||||
template <typename T>
|
||||
void sendBytes(
|
||||
|
|
@ -677,20 +652,5 @@ inline std::string recvString(int socket) {
|
|||
return std::string(value.data(), value.size());
|
||||
}
|
||||
|
||||
// Other helpers
|
||||
std::string sockaddrToString(struct sockaddr* addr);
|
||||
|
||||
std::pair<int, PortType> listen(PortType port);
|
||||
|
||||
int connect(
|
||||
const std::string& address,
|
||||
PortType port,
|
||||
bool wait = true,
|
||||
const std::chrono::milliseconds& timeout = kNoTimeout);
|
||||
|
||||
std::tuple<int, std::string> accept(
|
||||
int listenSocket,
|
||||
const std::chrono::milliseconds& timeout = kNoTimeout);
|
||||
|
||||
} // namespace tcputil
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -5,17 +5,8 @@
|
|||
namespace c10d {
|
||||
namespace tcputil {
|
||||
|
||||
#define AF_SELECTED AF_INET
|
||||
#define CONNECT_SOCKET_OFFSET 1
|
||||
|
||||
// Closes `socket` through Winsock; the result of closesocket() is ignored.
inline void closeSocket(int socket) {
  ::closesocket(socket);
}
|
||||
|
||||
inline int setSocketAddrReUse(int socket) {
|
||||
bool optval = false;
|
||||
return ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char *)&optval,
|
||||
sizeof(bool));
|
||||
}
|
||||
|
||||
inline int poll(struct pollfd *fdArray, unsigned long fds, int timeout) {
|
||||
return WSAPoll(fdArray, fds, timeout);
|
||||
}
|
||||
|
|
@ -25,62 +16,6 @@ inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
|
|||
fds.push_back({(SOCKET)socket, events});
|
||||
}
|
||||
|
||||
inline void waitSocketConnected(
|
||||
int socket,
|
||||
struct ::addrinfo *nextAddr,
|
||||
std::chrono::milliseconds timeout,
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> startTime) {
|
||||
unsigned long block_mode = 1;
|
||||
SYSCHECK_ERR_RETURN_NEG1(ioctlsocket(socket, FIONBIO, &block_mode));
|
||||
|
||||
int ret;
|
||||
do {
|
||||
ret = connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
|
||||
if (ret == SOCKET_ERROR) {
|
||||
int err = WSAGetLastError();
|
||||
if (err == WSAEISCONN) {
|
||||
break;
|
||||
} else if (err == WSAEALREADY || err == WSAEWOULDBLOCK) {
|
||||
if (timeout != kNoTimeout) {
|
||||
const auto elapsed =
|
||||
std::chrono::high_resolution_clock::now() - startTime;
|
||||
if (elapsed > timeout) {
|
||||
errno = 0;
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
c10::str(
|
||||
kConnectTimeoutMsg,
|
||||
" Original timeout was ",
|
||||
timeout.count(),
|
||||
" ms."));
|
||||
}
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
continue;
|
||||
}
|
||||
throw std::system_error(err, std::system_category(),
|
||||
"Socket connect failed");
|
||||
}
|
||||
} while (ret == SOCKET_ERROR);
|
||||
|
||||
block_mode = 0;
|
||||
SYSCHECK_ERR_RETURN_NEG1(ioctlsocket(socket, FIONBIO, &block_mode));
|
||||
}
|
||||
|
||||
// All processes (applications or DLLs) that call Winsock
|
||||
// functions must initialize the use of the Windows Sockets
|
||||
// DLL before making other Winsock function calls.
|
||||
// This also makes certain that Winsock is supported on the system.
|
||||
// Ref to
|
||||
// https://docs.microsoft.com/en-us/windows/win32/winsock/initializing-winsock
|
||||
// Initializes Winsock (version 2.2) exactly once per process.
//
// Throws std::system_error if WSAStartup() fails. Because the exception leaves
// std::call_once, a later call will attempt the initialization again.
inline void socketInitialize() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    WSADATA wsa_data;
    // WSAStartup() reports failure by returning the error code itself (zero on
    // success); it never returns SOCKET_ERROR (-1), so a "!= -1" check cannot
    // detect a failed initialization.
    const int rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
    if (rc != 0) {
      throw std::system_error(rc, std::system_category(), "WSAStartup failed");
    }
  });
}
|
||||
|
||||
inline struct ::pollfd getPollfd(int socket, short events) {
|
||||
struct ::pollfd res = {(SOCKET)socket, events};
|
||||
return res;
|
||||
|
|
|
|||
41
torch/csrc/distributed/c10d/error.h
Normal file
41
torch/csrc/distributed/c10d/error.h
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <system_error>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace fmt {
|
||||
|
||||
template <>
|
||||
struct formatter<std::error_code> {
|
||||
constexpr decltype(auto) parse(format_parse_context& ctx) {
|
||||
return ctx.begin();
|
||||
}
|
||||
|
||||
template <typename FormatContext>
|
||||
decltype(auto) format(const std::error_code& err, FormatContext& ctx) {
|
||||
return format_to(ctx.out(),
|
||||
"({} error: {} - {})",
|
||||
err.category().name(),
|
||||
err.value(),
|
||||
err.message());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace fmt
|
||||
|
||||
namespace c10d {
|
||||
namespace detail {
|
||||
|
||||
// Wraps the calling thread's current `errno` value in a std::error_code that
// belongs to the generic (POSIX) category.
inline std::error_code lastError() noexcept {
  const int code = errno;
  return std::error_code{code, std::generic_category()};
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace c10d
|
||||
9
torch/csrc/distributed/c10d/exception.cpp
Normal file
9
torch/csrc/distributed/c10d/exception.cpp
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
#include <c10d/exception.h>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
// Out-of-line definition of the destructor declared in c10d/exception.h.
C10dError::~C10dError() = default;
|
||||
|
||||
// Out-of-line definition of the destructor declared in c10d/exception.h.
TimeoutError::~TimeoutError() = default;
|
||||
|
||||
} // namespace c10d
|
||||
45
torch/csrc/distributed/c10d/exception.h
Normal file
45
torch/csrc/distributed/c10d/exception.h
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
// Base class of the exceptions declared in this header. It adds no state on
// top of std::runtime_error.
class TORCH_API C10dError : public std::runtime_error {
 public:
  // Reuse every std::runtime_error constructor (message as std::string or
  // const char*).
  using std::runtime_error::runtime_error;

  // Copy and move construction.
  C10dError(const C10dError&) = default;
  C10dError(C10dError&&) = default;

  // Copy and move assignment.
  C10dError& operator=(const C10dError&) = default;
  C10dError& operator=(C10dError&&) = default;

  // Only declared here; the definition is not part of the class body.
  ~C10dError() override;
};
|
||||
|
||||
// Exception type derived from C10dError; it adds no state of its own.
class TORCH_API TimeoutError : public C10dError {
 public:
  // Reuse the constructors of C10dError.
  using C10dError::C10dError;

  // Copy and move construction.
  TimeoutError(const TimeoutError&) = default;
  TimeoutError(TimeoutError&&) = default;

  // Copy and move assignment.
  TimeoutError& operator=(const TimeoutError&) = default;
  TimeoutError& operator=(TimeoutError&&) = default;

  // Only declared here; the definition is not part of the class body.
  ~TimeoutError() override;
};
|
||||
|
||||
} // namespace c10d
|
||||
|
|
@ -911,7 +911,7 @@ Example::
|
|||
)")
|
||||
.def(
|
||||
py::init([](const std::string& host,
|
||||
::c10d::PortType port,
|
||||
uint16_t port,
|
||||
int worldSize,
|
||||
bool isServer,
|
||||
std::chrono::milliseconds timeout,
|
||||
|
|
|
|||
20
torch/csrc/distributed/c10d/logging.h
Normal file
20
torch/csrc/distributed/c10d/logging.h
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
// Logging helpers that take a libfmt format string followed by its arguments.
// Each macro expands to a LOG_IF statement that is gated on
// FLAGS_caffe2_log_level.

// Logged at ERROR severity when FLAGS_caffe2_log_level <= 2.
#define C10D_ERROR(...) \
  LOG_IF(ERROR, FLAGS_caffe2_log_level <= 2) << fmt::format(__VA_ARGS__)

// Logged at WARNING severity when FLAGS_caffe2_log_level <= 1.
#define C10D_WARNING(...) \
  LOG_IF(WARNING, FLAGS_caffe2_log_level <= 1) << fmt::format(__VA_ARGS__)

// Logged at INFO severity when FLAGS_caffe2_log_level <= 0.
#define C10D_INFO(...) \
  LOG_IF(INFO, FLAGS_caffe2_log_level <= 0) << fmt::format(__VA_ARGS__)
|
||||
847
torch/csrc/distributed/c10d/socket.cpp
Normal file
847
torch/csrc/distributed/c10d/socket.cpp
Normal file
|
|
@ -0,0 +1,847 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <c10d/socket.h>
|
||||
|
||||
#include <cstring>
#include <limits>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <mutex>
|
||||
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#else
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <poll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include <fmt/chrono.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <c10d/error.h>
|
||||
#include <c10d/exception.h>
|
||||
#include <c10d/logging.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace detail {
|
||||
namespace {
|
||||
#ifdef _WIN32
|
||||
|
||||
// Since Winsock uses the name `WSAPoll` instead of `poll`, we alias it here
|
||||
// to avoid #ifdefs in the source code.
|
||||
const auto pollFd = ::WSAPoll;
|
||||
|
||||
// Winsock's `getsockopt()` and `setsockopt()` functions expect option values to
|
||||
// be passed as `char*` instead of `void*`. We wrap them here to avoid redundant
|
||||
// casts in the source code.
|
||||
// Winsock flavour of getsockopt(): accepts the option buffer as void* and
// performs the char* cast that the Winsock signature requires.
int getSocketOption(SOCKET s, int level, int optname, void* optval, int* optlen) {
  auto* buffer = static_cast<char*>(optval);
  return ::getsockopt(s, level, optname, buffer, optlen);
}
|
||||
|
||||
// Winsock flavour of setsockopt(): accepts the option buffer as const void*
// and performs the const char* cast that the Winsock signature requires.
int setSocketOption(SOCKET s, int level, int optname, const void* optval, int optlen) {
  const auto* buffer = static_cast<const char*>(optval);
  return ::setsockopt(s, level, optname, buffer, optlen);
}
|
||||
|
||||
// Winsock has its own error codes which differ from Berkeley's. Fortunately the
|
||||
// C++ Standard Library on Windows can map them to standard error codes.
|
||||
inline std::error_code getSocketError() noexcept {
|
||||
return std::error_code{::WSAGetLastError(), std::system_category()};
|
||||
}
|
||||
|
||||
// Overwrites the calling thread's last Winsock error with `val`, so that a
// later getSocketError() reports it.
inline void setSocketError(int val) noexcept {
  const int code = val;
  ::WSASetLastError(code);
}
|
||||
|
||||
#else
|
||||
|
||||
const auto pollFd = ::poll;
|
||||
|
||||
const auto getSocketOption = ::getsockopt;
|
||||
const auto setSocketOption = ::setsockopt;
|
||||
|
||||
inline std::error_code getSocketError() noexcept {
|
||||
return lastError();
|
||||
}
|
||||
|
||||
// Stores `val` in `errno`, so that a later getSocketError() reports it.
inline void setSocketError(int val) noexcept {
  const int code = val;
  errno = code;
}
|
||||
|
||||
#endif
|
||||
|
||||
// Suspends the current thread for the specified duration.
|
||||
void delay(std::chrono::seconds d) {
|
||||
#ifdef _WIN32
|
||||
std::this_thread::sleep_for(d);
|
||||
#else
|
||||
::timespec req{};
|
||||
req.tv_sec = d.count();
|
||||
|
||||
// The C++ Standard does not specify whether `sleep_for()` should be signal-
|
||||
// aware; therefore, we use the `nanosleep()` syscall.
|
||||
if (::nanosleep(&req, nullptr) != 0) {
|
||||
std::error_code err = getSocketError();
|
||||
// We don't care about error conditions other than EINTR since a failure
|
||||
// here is not critical.
|
||||
if (err == std::errc::interrupted) {
|
||||
throw std::system_error{err};
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
class SocketListenOp;
|
||||
class SocketConnectOp;
|
||||
} // namespace
|
||||
|
||||
// Owns a native socket handle and closes it on destruction. Instances can be
// neither copied nor moved.
class SocketImpl {
  friend class SocketListenOp;
  friend class SocketConnectOp;

 public:
  // Native handle type and the platform's "no socket" sentinel.
#ifdef _WIN32
  using Handle = SOCKET;

  static constexpr Handle invalid_socket = INVALID_SOCKET;
#else
  using Handle = int;

  static constexpr Handle invalid_socket = -1;
#endif

  // Takes ownership of `hnd`.
  explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {}

  SocketImpl(const SocketImpl& other) = delete;
  SocketImpl(SocketImpl&& other) noexcept = delete;

  SocketImpl& operator=(const SocketImpl& other) = delete;
  SocketImpl& operator=(SocketImpl&& other) noexcept = delete;

  ~SocketImpl();

  // Accepts a pending connection and returns the socket that represents it.
  std::unique_ptr<SocketImpl> accept() const;

  void closeOnExec() noexcept;

  void enableNonBlocking();

  void disableNonBlocking();

  // The enable*() members return true if the option was applied.
  bool enableNoDelay() noexcept;

  bool enableDualStack() noexcept;

#ifdef _WIN32
  bool enableExclusiveAddressUse() noexcept;
#else
  bool enableAddressReuse() noexcept;
#endif

  // Returns the local port of the socket in host byte order.
  std::uint16_t getPort() const;

  Handle handle() const noexcept {
    return hnd_;
  }

 private:
  // Sets a boolean socket option; returns true on success.
  bool setSocketFlag(int level, int optname, bool value) noexcept;

  Handle hnd_;
};
|
||||
|
||||
// Closes the owned handle with the platform's close function. The result of
// the close call is ignored.
SocketImpl::~SocketImpl() {
#ifndef _WIN32
  ::close(hnd_);
#else
  ::closesocket(hnd_);
#endif
}
|
||||
|
||||
std::unique_ptr<SocketImpl> SocketImpl::accept() const {
|
||||
::sockaddr_storage addr_s{};
|
||||
|
||||
auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
|
||||
|
||||
::socklen_t addr_len = sizeof(addr_s);
|
||||
|
||||
Handle hnd = ::accept(hnd_, addr_ptr, &addr_len);
|
||||
if (hnd == invalid_socket) {
|
||||
std::error_code err = getSocketError();
|
||||
if (err == std::errc::interrupted) {
|
||||
throw std::system_error{err};
|
||||
}
|
||||
|
||||
std::string msg{};
|
||||
if (err == std::errc::invalid_argument) {
|
||||
msg = fmt::format("The server socket on {} is not listening for connections.", *this);
|
||||
} else {
|
||||
msg = fmt::format("The server socket on {} has failed to accept a connection {}.", *this, err);
|
||||
}
|
||||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw SocketError{msg};
|
||||
}
|
||||
|
||||
::addrinfo addr{};
|
||||
addr.ai_addr = addr_ptr;
|
||||
addr.ai_addrlen = addr_len;
|
||||
|
||||
C10D_INFO("The server socket on {} has accepted a connection from {}.", *this, addr);
|
||||
|
||||
auto impl = std::make_unique<SocketImpl>(hnd);
|
||||
|
||||
// Make sure that we do not "leak" our file descriptors to child processes.
|
||||
impl->closeOnExec();
|
||||
|
||||
if (!impl->enableNoDelay()) {
|
||||
C10D_WARNING("The no-delay option cannot be enabled for the client socket on {}.", addr);
|
||||
}
|
||||
|
||||
return impl;
|
||||
}
|
||||
|
||||
void SocketImpl::closeOnExec() noexcept {
|
||||
#ifndef _WIN32
|
||||
::fcntl(hnd_, F_SETFD, FD_CLOEXEC);
|
||||
#endif
|
||||
}
|
||||
|
||||
void SocketImpl::enableNonBlocking() {
|
||||
#ifdef _WIN32
|
||||
unsigned long value = 1;
|
||||
if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
int flg = ::fcntl(hnd_, F_GETFL);
|
||||
if (flg != -1) {
|
||||
if (::fcntl(hnd_, F_SETFL, flg | O_NONBLOCK) == 0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
throw SocketError{"The socket cannot be switched to non-blocking mode."};
|
||||
}
|
||||
|
||||
// TODO: Remove once we migrate everything to non-blocking mode.
|
||||
void SocketImpl::disableNonBlocking() {
|
||||
#ifdef _WIN32
|
||||
unsigned long value = 0;
|
||||
if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
int flg = ::fcntl(hnd_, F_GETFL);
|
||||
if (flg != -1) {
|
||||
if (::fcntl(hnd_, F_SETFL, flg & ~O_NONBLOCK) == 0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
throw SocketError{"The socket cannot be switched to blocking mode."};
|
||||
}
|
||||
|
||||
bool SocketImpl::enableNoDelay() noexcept {
|
||||
return setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true);
|
||||
}
|
||||
|
||||
bool SocketImpl::enableDualStack() noexcept {
|
||||
return setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false);
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
bool SocketImpl::enableAddressReuse() noexcept {
|
||||
return setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
bool SocketImpl::enableExclusiveAddressUse() noexcept {
|
||||
return setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true);
|
||||
}
|
||||
#endif
|
||||
|
||||
std::uint16_t SocketImpl::getPort() const {
|
||||
::sockaddr_storage addr_s{};
|
||||
|
||||
::socklen_t addr_len = sizeof(addr_s);
|
||||
|
||||
if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) != 0) {
|
||||
throw SocketError{"The port number of the socket cannot be retrieved."};
|
||||
}
|
||||
|
||||
if (addr_s.ss_family == AF_INET) {
|
||||
return ntohs(reinterpret_cast<::sockaddr_in*> (&addr_s)->sin_port);
|
||||
} else {
|
||||
return ntohs(reinterpret_cast<::sockaddr_in6*>(&addr_s)->sin6_port);
|
||||
}
|
||||
}
|
||||
|
||||
bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept {
|
||||
#ifdef _WIN32
|
||||
auto buf = value ? TRUE : FALSE;
|
||||
#else
|
||||
auto buf = value ? 1 : 0;
|
||||
#endif
|
||||
return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct addrinfo_delete {
|
||||
void operator()(::addrinfo* addr) const noexcept {
|
||||
::freeaddrinfo(addr);
|
||||
}
|
||||
};
|
||||
|
||||
using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>;
|
||||
|
||||
class SocketListenOp {
|
||||
public:
|
||||
SocketListenOp(std::uint16_t port, const SocketOptions& opts);
|
||||
|
||||
std::unique_ptr<SocketImpl> run();
|
||||
|
||||
private:
|
||||
bool tryListen(int family);
|
||||
|
||||
bool tryListen(const ::addrinfo& addr);
|
||||
|
||||
template <typename... Args>
|
||||
void recordError(fmt::string_view format, Args&&... args) {
|
||||
auto msg = fmt::format(format, std::forward<Args>(args)...);
|
||||
|
||||
C10D_WARNING(msg);
|
||||
|
||||
errors_.emplace_back(std::move(msg));
|
||||
}
|
||||
|
||||
std::string port_;
|
||||
const SocketOptions* opts_;
|
||||
std::vector<std::string> errors_{};
|
||||
std::unique_ptr<SocketImpl> socket_{};
|
||||
};
|
||||
|
||||
// Stores the port as decimal text and keeps a pointer to `opts`.
SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions& opts)
    : port_{fmt::to_string(port)},
      opts_{&opts} {}
|
||||
|
||||
std::unique_ptr<SocketImpl> SocketListenOp::run() {
|
||||
if (opts_->prefer_ipv6()) {
|
||||
C10D_INFO("The server socket will attempt to listen on an IPv6 address.");
|
||||
if (tryListen(AF_INET6)) {
|
||||
return std::move(socket_);
|
||||
}
|
||||
|
||||
C10D_INFO("The server socket will attempt to listen on an IPv4 address.");
|
||||
if (tryListen(AF_INET)) {
|
||||
return std::move(socket_);
|
||||
}
|
||||
} else {
|
||||
C10D_INFO("The server socket will attempt to listen on an IPv4 or IPv6 address.");
|
||||
if (tryListen(AF_UNSPEC)) {
|
||||
return std::move(socket_);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto* msg = "The server socket has failed to listen on any local network address.";
|
||||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
|
||||
}
|
||||
|
||||
// Resolves the local (passive) addresses of `family` for the configured port
// and tries to listen on each of them in turn. Returns true on the first
// success, false if resolution fails or no address can be used.
bool SocketListenOp::tryListen(int family) {
  ::addrinfo hints{};
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_family = family;
  hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;

  ::addrinfo* head = nullptr;

  const int rc = ::getaddrinfo(nullptr, port_.c_str(), &hints, &head);
  if (rc != 0) {
    const char* gai_err = ::gai_strerror(rc);

    const char* family_label = "";
    if (family == AF_INET) {
      family_label = "IPv4 ";
    } else if (family == AF_INET6) {
      family_label = "IPv6 ";
    }

    recordError("The local {}network addresses cannot be retrieved (gai error: {} - {}).",
                family_label,
                rc,
                gai_err);

    return false;
  }

  // Frees the address list when this function returns.
  addrinfo_ptr owner{head};

  ::addrinfo* current = head;
  while (current != nullptr) {
    C10D_INFO("The server socket is attempting to listen on {}.", *current);
    if (tryListen(*current)) {
      return true;
    }
    current = current->ai_next;
  }

  return false;
}
|
||||
|
||||
bool SocketListenOp::tryListen(const ::addrinfo& addr) {
|
||||
SocketImpl::Handle hnd = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
|
||||
if (hnd == SocketImpl::invalid_socket) {
|
||||
recordError("The server socket cannot be initialized on {} {}.", addr, getSocketError());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
socket_ = std::make_unique<SocketImpl>(hnd);
|
||||
|
||||
#ifndef _WIN32
|
||||
if (!socket_->enableAddressReuse()) {
|
||||
C10D_WARNING("The address reuse option cannot be enabled for the server socket on {}.", addr);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
// The SO_REUSEADDR flag has a significantly different behavior on Windows
|
||||
// compared to Unix-like systems. It allows two or more processes to share
|
||||
// the same port simultaneously, which is totally unsafe.
|
||||
//
|
||||
// Here we follow the recommendation of Microsoft and use the non-standard
|
||||
// SO_EXCLUSIVEADDRUSE flag instead.
|
||||
if (!socket_->enableExclusiveAddressUse()) {
|
||||
C10D_WARNING("The exclusive address use option cannot be enabled for the server socket on {}.",
|
||||
addr);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Not all operating systems support dual-stack sockets by default. Since we
|
||||
// wish to use our IPv6 socket for IPv4 communication as well, we explicitly
|
||||
// ask the system to enable it.
|
||||
if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) {
|
||||
C10D_WARNING("The server socket does not support IPv4 communication on {}.", addr);
|
||||
}
|
||||
|
||||
if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) {
|
||||
recordError("The server socket has failed to bind to {} {}.", addr, getSocketError());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-argument-comment)
|
||||
if (::listen(socket_->handle(), /*backlog=*/2048) != 0) {
|
||||
recordError("The server socket has failed to listen on {} {}.", addr, getSocketError());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
socket_->closeOnExec();
|
||||
|
||||
C10D_INFO("The server socket has started to listen on {}.", addr);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
class SocketConnectOp {
|
||||
using Clock = std::chrono::steady_clock;
|
||||
using Duration = std::chrono::steady_clock::duration;
|
||||
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
|
||||
|
||||
static const std::chrono::seconds delay_duration_;
|
||||
|
||||
enum class ConnectResult {
|
||||
Success,
|
||||
Error,
|
||||
Retry,
|
||||
TimeOut
|
||||
};
|
||||
|
||||
public:
|
||||
SocketConnectOp(const std::string& host, std::uint16_t port, const SocketOptions& opts);
|
||||
|
||||
std::unique_ptr<SocketImpl> run();
|
||||
|
||||
private:
|
||||
bool tryConnect(int family);
|
||||
|
||||
ConnectResult tryConnect(const ::addrinfo& addr);
|
||||
|
||||
ConnectResult tryConnectCore(const ::addrinfo& addr);
|
||||
|
||||
template <typename... Args>
|
||||
void recordError(fmt::string_view format, Args&&... args) {
|
||||
auto msg = fmt::format(format, std::forward<Args>(args)...);
|
||||
|
||||
C10D_WARNING(msg);
|
||||
|
||||
errors_.emplace_back(std::move(msg));
|
||||
}
|
||||
|
||||
const char* host_;
|
||||
std::string port_;
|
||||
const SocketOptions* opts_;
|
||||
TimePoint deadline_{};
|
||||
std::vector<std::string> errors_{};
|
||||
std::unique_ptr<SocketImpl> socket_{};
|
||||
};
|
||||
|
||||
const std::chrono::seconds SocketConnectOp::delay_duration_{1};
|
||||
|
||||
// Keeps a pointer to the characters of `host` (no copy is made), stores the
// port as decimal text and keeps a pointer to `opts`.
SocketConnectOp::SocketConnectOp(const std::string& host,
                                 std::uint16_t port,
                                 const SocketOptions& opts)
    : host_{host.c_str()},
      port_{fmt::to_string(port)},
      opts_{&opts} {}
|
||||
|
||||
std::unique_ptr<SocketImpl> SocketConnectOp::run() {
|
||||
if (opts_->prefer_ipv6()) {
|
||||
C10D_INFO("The client socket will attempt to connect to an IPv6 address of ({}, {}).",
|
||||
host_,
|
||||
port_);
|
||||
|
||||
if (tryConnect(AF_INET6)) {
|
||||
return std::move(socket_);
|
||||
}
|
||||
|
||||
C10D_INFO("The client socket will attempt to connect to an IPv4 address of ({}, {}).",
|
||||
host_,
|
||||
port_);
|
||||
|
||||
if (tryConnect(AF_INET)) {
|
||||
return std::move(socket_);
|
||||
}
|
||||
} else {
|
||||
C10D_INFO("The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).",
|
||||
host_,
|
||||
port_);
|
||||
|
||||
if (tryConnect(AF_UNSPEC)) {
|
||||
return std::move(socket_);
|
||||
}
|
||||
}
|
||||
|
||||
auto msg = fmt::format(
|
||||
"The client socket has failed to connect to any network address of ({}, {}).",
|
||||
host_,
|
||||
port_);
|
||||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
|
||||
}
|
||||
|
||||
// Resolves the addresses of the host for `family` and tries to connect to each
// of them, repeating the whole pass for as long as at least one address asked
// for a retry.
//
// Returns true once a connection is established and false if resolution fails
// or every address failed without asking for a retry. Throws TimeoutError when
// an attempt reports that the deadline has passed. The deadline is computed
// here, so every call starts with the full connect timeout.
bool SocketConnectOp::tryConnect(int family) {
  ::addrinfo hints{}, *naked_result = nullptr;

  hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result);
  if (r != 0) {
    const char* gai_err = ::gai_strerror(r);

    recordError("The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
                family == AF_INET ? "IPv4 " : family == AF_INET6 ? "IPv6 " : "",
                host_,
                port_,
                r,
                gai_err);

    return false;
  }

  // Frees the address list when this function returns or throws.
  addrinfo_ptr result{naked_result};

  deadline_ = Clock::now() + opts_->connect_timeout();

  bool retry; // NOLINT(cppcoreguidelines-init-variables)
  do {
    retry = false;

    // Only the errors of the latest pass are reported.
    errors_.clear();

    for (::addrinfo* addr = naked_result; addr != nullptr; addr = addr->ai_next) {
      C10D_INFO("The client socket is attempting to connect to {}.", *addr);

      ConnectResult cr = tryConnect(*addr);
      if (cr == ConnectResult::Success) {
        return true;
      }

      if (cr == ConnectResult::TimeOut) {
        auto msg = fmt::format(
            "The client socket has timed out after {} while trying to connect to ({}, {}).",
            opts_->connect_timeout(),
            host_,
            port_);

        // `msg` embeds the caller-supplied host name; log it as an argument so
        // that braces in it are not parsed as replacement fields.
        C10D_ERROR("{}", msg);

        throw TimeoutError{msg};
      }

      if (cr == ConnectResult::Retry) {
        retry = true;
      }
    }
  } while (retry);

  return false;
}
|
||||
|
||||
// Makes a single connection attempt against `addr`.
//
// Returns Success with socket_ connected (and back in blocking mode), TimeOut
// if the deadline has passed or is too close for another retry, Retry after
// sleeping when the server refused or reset the connection, and Error for any
// other failure. Throws std::system_error if the attempt was interrupted by a
// signal.
SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(const ::addrinfo& addr) {
  if (Clock::now() >= deadline_) {
    return ConnectResult::TimeOut;
  }

  SocketImpl::Handle raw = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
  if (raw == SocketImpl::invalid_socket) {
    recordError("The client socket cannot be initialized to connect to {} {}.",
                addr,
                getSocketError());

    return ConnectResult::Error;
  }

  socket_ = std::make_unique<SocketImpl>(raw);

  socket_->enableNonBlocking();

  const ConnectResult outcome = tryConnectCore(addr);

  if (outcome == ConnectResult::TimeOut) {
    return outcome;
  }

  if (outcome == ConnectResult::Error) {
    const std::error_code err = getSocketError();
    if (err == std::errc::interrupted) {
      throw std::system_error{err};
    }

    // Retry if the server is not yet listening or if its backlog is exhausted.
    const bool server_not_ready =
        err == std::errc::connection_refused || err == std::errc::connection_reset;
    if (!server_not_ready) {
      recordError("The client socket has failed to connect to {} {}.", addr, err);

      return ConnectResult::Error;
    }

    C10D_WARNING("The server socket on {} is not yet listening {}.", addr, err);

    // Give up if the deadline would pass during the pause.
    if (!(Clock::now() < deadline_ - delay_duration_)) {
      return ConnectResult::TimeOut;
    }

    // Wait a little to avoid choking the server.
    delay(delay_duration_);

    return ConnectResult::Retry;
  }

  socket_->closeOnExec();

  // TODO: Remove once we fully migrate to non-blocking mode.
  socket_->disableNonBlocking();

  C10D_INFO("The client socket has connected to {} on {}.", addr, *socket_);

  const bool no_delay = socket_->enableNoDelay();
  if (!no_delay) {
    C10D_WARNING("The no-delay option cannot be enabled for the client socket on {}.", *socket_);
  }

  return ConnectResult::Success;
}
|
||||
|
||||
// Starts a connect() on socket_ and, if it does not complete immediately,
// waits with poll() until the socket becomes writable or the deadline passes.
//
// Returns Success, TimeOut or Error. On Error the cause can be read with
// getSocketError(): it is either left behind by the failing call or, for a
// failed asynchronous connect, copied from the socket's SO_ERROR option.
SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(const ::addrinfo& addr) {
  int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen);
  if (r == 0) {
    return ConnectResult::Success;
  }

  std::error_code err = getSocketError();
  if (err == std::errc::already_connected) {
    return ConnectResult::Success;
  }

  // Anything other than "still in progress" is a real failure.
  if (err != std::errc::operation_in_progress && err != std::errc::operation_would_block) {
    return ConnectResult::Error;
  }

  Duration remaining = deadline_ - Clock::now();
  if (remaining <= Duration::zero()) {
    return ConnectResult::TimeOut;
  }

  ::pollfd pfd{};
  pfd.fd = socket_->handle();
  pfd.events = POLLOUT;

  auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(remaining);

  // poll() takes its timeout as an `int`. A very long remaining time must be
  // clamped: an unchecked narrowing could produce a negative value, which
  // poll() interprets as "wait forever".
  constexpr int max_poll_ms = std::numeric_limits<int>::max();
  const int poll_ms = ms.count() > max_poll_ms ? max_poll_ms : static_cast<int>(ms.count());

  r = pollFd(&pfd, 1, poll_ms);
  if (r == 0) {
    return ConnectResult::TimeOut;
  }
  if (r == -1) {
    return ConnectResult::Error;
  }

  // The socket is ready; SO_ERROR tells whether the connect succeeded.
  int err_code = 0;

  ::socklen_t err_len = sizeof(int);

  r = getSocketOption(socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len);
  if (r != 0) {
    return ConnectResult::Error;
  }

  if (err_code != 0) {
    // Publish the error so that the caller's getSocketError() sees it.
    setSocketError(err_code);

    return ConnectResult::Error;
  } else {
    return ConnectResult::Success;
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Performs the one-time, process-wide socket library setup. On Windows this
// starts Winsock 2.2 and throws SocketError if that fails; on other platforms
// it does nothing.
void Socket::initialize() {
#ifdef _WIN32
  static std::once_flag init_flag{};

  // All processes that call socket functions on Windows must first initialize
  // the Winsock library.
  auto startWinsock = []() {
    WSADATA data{};
    const int rc = ::WSAStartup(MAKEWORD(2, 2), &data);
    if (rc != 0) {
      throw SocketError{"The initialization of Winsock has failed."};
    }
  };

  std::call_once(init_flag, startWinsock);
#endif
}
|
||||
|
||||
// Creates a server socket listening on `port` by running a SocketListenOp.
Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) {
  SocketListenOp listen_op{port, opts};

  std::unique_ptr<SocketImpl> impl = listen_op.run();

  return Socket{std::move(impl)};
}
|
||||
|
||||
// Creates a client socket connected to (host, port) by running a
// SocketConnectOp.
Socket Socket::connect(const std::string& host, std::uint16_t port, const SocketOptions& opts) {
  SocketConnectOp connect_op{host, port, opts};

  std::unique_ptr<SocketImpl> impl = connect_op.run();

  return Socket{std::move(impl)};
}
|
||||
|
||||
// Defaulted move constructor, defined in this file rather than in the class
// body (presumably because SocketImpl is only complete here -- TODO confirm).
Socket::Socket(Socket&& other) noexcept = default;
|
||||
|
||||
// Defaulted move assignment, defined in this file rather than in the class
// body (presumably because SocketImpl is only complete here -- TODO confirm).
Socket& Socket::operator=(Socket&& other) noexcept = default;
|
||||
|
||||
// Defaulted destructor, defined in this file rather than in the class body
// (presumably because SocketImpl is only complete here -- TODO confirm).
Socket::~Socket() = default;
|
||||
|
||||
// Accepts a pending connection and returns it as a new Socket.
// Throws SocketError if this Socket holds no implementation object.
Socket Socket::accept() const {
  if (!impl_) {
    throw SocketError{"The socket is not initialized."};
  }

  return Socket{impl_->accept()};
}
|
||||
|
||||
int Socket::handle() const noexcept {
|
||||
if (impl_) {
|
||||
return impl_->handle();
|
||||
}
|
||||
return SocketImpl::invalid_socket;
|
||||
}
|
||||
|
||||
std::uint16_t Socket::port() const {
|
||||
if (impl_) {
|
||||
return impl_->getPort();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Takes ownership of the given implementation object.
Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept
    : impl_{std::move(impl)} {
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Out-of-line definition of the SocketError destructor.
SocketError::~SocketError() = default;
|
||||
|
||||
} // namespace c10d
|
||||
|
||||
//
|
||||
// libfmt formatters for `addrinfo` and `Socket`
|
||||
//
|
||||
namespace fmt {
|
||||
|
||||
template <>
|
||||
struct formatter<::addrinfo> {
|
||||
constexpr decltype(auto) parse(format_parse_context& ctx) {
|
||||
return ctx.begin();
|
||||
}
|
||||
|
||||
template <typename FormatContext>
|
||||
decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) {
|
||||
char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT
|
||||
|
||||
int r = ::getnameinfo(addr.ai_addr,
|
||||
addr.ai_addrlen,
|
||||
host,
|
||||
NI_MAXHOST,
|
||||
port,
|
||||
NI_MAXSERV,
|
||||
NI_NUMERICSERV);
|
||||
if (r != 0) {
|
||||
return format_to(ctx.out(), "?UNKNOWN?");
|
||||
}
|
||||
|
||||
if (addr.ai_addr->sa_family == AF_INET) {
|
||||
return format_to(ctx.out(), "{}:{}", host, port);
|
||||
} else {
|
||||
return format_to(ctx.out(), "[{}]:{}", host, port);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct formatter<c10d::detail::SocketImpl> {
|
||||
constexpr decltype(auto) parse(format_parse_context& ctx) {
|
||||
return ctx.begin();
|
||||
}
|
||||
|
||||
template <typename FormatContext>
|
||||
decltype(auto) format(const c10d::detail::SocketImpl& socket, FormatContext& ctx) {
|
||||
::sockaddr_storage addr_s{};
|
||||
|
||||
auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
|
||||
|
||||
::socklen_t addr_len = sizeof(addr_s);
|
||||
|
||||
if (::getsockname(socket.handle(), addr_ptr, &addr_len) != 0) {
|
||||
return format_to(ctx.out(), "?UNKNOWN?");
|
||||
}
|
||||
|
||||
::addrinfo addr{};
|
||||
addr.ai_addr = addr_ptr;
|
||||
addr.ai_addrlen = addr_len;
|
||||
|
||||
return format_to(ctx.out(), "{}", addr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace fmt
|
||||
100
torch/csrc/distributed/c10d/socket.h
Normal file
100
torch/csrc/distributed/c10d/socket.h
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10d/exception.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace detail {
|
||||
|
||||
// Options accepted by Socket::listen() and Socket::connect(). Each setter
// returns *this so that calls can be chained.
class SocketOptions {
  // IPv6 preference flag; enabled by default.
  bool prefer_ipv6_ = true;

  // Connect timeout; 30 seconds by default.
  std::chrono::seconds connect_timeout_{30};

 public:
  // --- prefer_ipv6 ---------------------------------------------------------
  bool prefer_ipv6() const noexcept {
    return prefer_ipv6_;
  }

  SocketOptions& prefer_ipv6(bool value) noexcept {
    prefer_ipv6_ = value;
    return *this;
  }

  // --- connect_timeout -----------------------------------------------------
  std::chrono::seconds connect_timeout() const noexcept {
    return connect_timeout_;
  }

  SocketOptions& connect_timeout(std::chrono::seconds value) noexcept {
    connect_timeout_ = value;
    return *this;
  }
};
|
||||
|
||||
class SocketImpl;
|
||||
|
||||
class Socket {
|
||||
public:
|
||||
// This function initializes the underlying socket library and must be called
|
||||
// before any other socket function.
|
||||
static void initialize();
|
||||
|
||||
static Socket listen(std::uint16_t port, const SocketOptions& opts = {});
|
||||
|
||||
static Socket connect(const std::string& host, std::uint16_t port, const SocketOptions& opts = {});
|
||||
|
||||
Socket() noexcept = default;
|
||||
|
||||
Socket(const Socket& other) = delete;
|
||||
|
||||
Socket& operator=(const Socket& other) = delete;
|
||||
|
||||
Socket(Socket&& other) noexcept;
|
||||
|
||||
Socket& operator=(Socket&& other) noexcept;
|
||||
|
||||
~Socket();
|
||||
|
||||
Socket accept() const;
|
||||
|
||||
int handle() const noexcept;
|
||||
|
||||
std::uint16_t port() const;
|
||||
|
||||
private:
|
||||
explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept;
|
||||
|
||||
std::unique_ptr<SocketImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Exception type for socket failures. Reuses the constructors of C10dError.
class TORCH_API SocketError : public C10dError {
 public:
  using C10dError::C10dError;

  // Copy and move construction.
  SocketError(const SocketError&) = default;
  SocketError(SocketError&&) = default;

  // Copy and move assignment.
  SocketError& operator=(const SocketError&) = default;
  SocketError& operator=(SocketError&&) = default;

  // Only declared here; the definition is out of line.
  ~SocketError() override;
};
|
||||
|
||||
} // namespace c10d
|
||||
|
|
@ -111,7 +111,7 @@ class C10dRendezvousBackend(RendezvousBackend):
|
|||
def _call_store(self, store_op: str, *args, **kwargs) -> Any:
|
||||
try:
|
||||
return getattr(self._store, store_op)(*args, **kwargs)
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
except (ValueError, RuntimeError, TimeoutError) as exc:
|
||||
raise RendezvousConnectionError(
|
||||
"The connection to the C10d store has failed. See inner exception for details."
|
||||
) from exc
|
||||
|
|
@ -164,7 +164,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
|
|||
log.info(msg)
|
||||
|
||||
break
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
except (ValueError, RuntimeError, TimeoutError) as exc:
|
||||
# If we heuristically inferred the value of is_host as True and our
|
||||
# first attempt to instantiate the TCP store has failed, try it one
|
||||
# more time with is_host set to False. As an edge case there can be
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from torch.distributed.elastic.utils.logging import get_logger
|
|||
log = get_logger()
|
||||
|
||||
_ADDRESS_IN_USE = "Address already in use"
|
||||
_CONNECT_TIMEOUT = "connect() timed out."
|
||||
_SOCKET_TIMEOUT = "Socket Timeout"
|
||||
|
||||
_MEMBER_CHECKIN = "_tcp_store/num_members"
|
||||
|
|
@ -75,18 +74,14 @@ def create_c10d_store(
|
|||
# detects timeouts and port conflicts in their own unittests
|
||||
# see - caffe2/torch/testing/_internal/common_utils.py
|
||||
# TODO properly map the exceptions in pybind (c10d/init.cpp)
|
||||
if _CONNECT_TIMEOUT in str(e) and not is_server:
|
||||
raise TimeoutError(
|
||||
f"timed out waiting for tcp store's server: {server_addr}:{port}"
|
||||
) from e
|
||||
elif str(e) == _ADDRESS_IN_USE: # this will only happen on the server
|
||||
if str(e) == _ADDRESS_IN_USE: # this will only happen on the server
|
||||
if attempt < retries:
|
||||
log.warning(
|
||||
f"port: {port} already in use, attempt: [{attempt}/{retries}]"
|
||||
)
|
||||
attempt += 1
|
||||
else:
|
||||
raise IOError(
|
||||
raise RuntimeError(
|
||||
f"on {server_addr}, port: {port} already in use"
|
||||
) from e
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user