Revise the socket implementation of c10d (#68226)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68226

**Note that this PR is unusually big due to the urgency of the changes. Please reach out to me in case you wish to have a "pair" review.**

This PR introduces a major refactoring of the socket implementation of the c10d library. A big portion of the logic is now contained in the `Socket` class, and a follow-up PR will further consolidate the remaining parts. As of today, the changes in this PR offer:

 - significantly better error handling and much more verbose logging (see the example output below)
 - explicit support for IPv6 and dual-stack sockets
 - correct handling of signal interrupts
 - better Windows support

A follow-up PR will consolidate `send`/`recv` logic into `Socket` and fully migrate to non-blocking sockets.

## Example Output

```
[I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501).
[I logging.h:21] The client socket is attempting to connect to [localhost]:29501.
[W logging.h:28] The server socket on [localhost]:29501 is not yet listening (Error: 111 - Connection refused), retrying...
[I logging.h:21] The server socket will attempt to listen on an IPv6 address.
[I logging.h:21] The server socket is attempting to listen on [::]:29501.
[I logging.h:21] The server socket has started to listen on [::]:29501.
[I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501).
[I logging.h:21] The client socket is attempting to connect to [localhost]:29501.
[I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42650.
[I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42650.
[I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42722.
[I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42722.
[I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501).
[I logging.h:21] The client socket is attempting to connect to [localhost]:29501.
[I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42724.
[I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42724.
[I logging.h:21] The client socket will attempt to connect to an IPv6 address on (127.0.0.1, 29501).
[I logging.h:21] The client socket is attempting to connect to [localhost]:29501.
[I logging.h:21] The client socket has connected to [localhost]:29501 on [localhost]:42726.
[I logging.h:21] The server socket on [::]:29501 has accepted a connection from [localhost]:42726.
```
ghstack-source-id: 143501987

Test Plan: Run existing unit and integration tests on devserver, Fedora, Ubuntu, macOS Big Sur, Windows 10.

Reviewed By: Babar, wilson100hong, mrshenli

Differential Revision: D32372333

fbshipit-source-id: 2204ffa28ed0d3683a9cb3ebe1ea8d92a831325a
This commit is contained in:
Can Balioglu 2021-11-16 20:47:57 -08:00 committed by Facebook GitHub Bot
parent 4c346bd073
commit 6e640a0acf
19 changed files with 1130 additions and 572 deletions

View File

@ -1005,6 +1005,7 @@ if __name__ == '__main__':
'include/torch/csrc/autograd/generated/*.h',
'include/torch/csrc/autograd/utils/*.h',
'include/torch/csrc/cuda/*.h',
'include/torch/csrc/distributed/c10d/exception.h',
'include/torch/csrc/jit/*.h',
'include/torch/csrc/jit/backends/*.h',
'include/torch/csrc/jit/generated/*.h',

View File

@ -131,7 +131,7 @@ class DistributedUtilTest(TestCase):
server_port=pick_free_port,
timeout=1,
)
with self.assertRaises(IOError):
with self.assertRaises(RuntimeError):
create_c10d_store(
is_server=True, server_addr=server_addr, server_port=store1.port
)
@ -142,7 +142,7 @@ class DistributedUtilTest(TestCase):
port = sock.getsockname()[1]
# on the worker port conflict shouldn't matter, it should just timeout
# since we never created a server
with self.assertRaises(IOError):
with self.assertRaises(TimeoutError):
create_c10d_store(
is_server=False,
server_addr=socket.gethostname(),

View File

@ -157,10 +157,7 @@ class TCPStoreTest(TestCase, StoreTestBase):
return store
def test_address_already_in_use(self):
if sys.platform == "win32":
err_msg_reg = "Only one usage of each socket address*"
else:
err_msg_reg = "^Address already in use$"
err_msg_reg = "^The server socket has failed to listen on any local "
with self.assertRaisesRegex(RuntimeError, err_msg_reg):
addr = DEFAULT_HOSTNAME
port = common.find_free_port()

View File

@ -384,24 +384,26 @@ libtorch_core_sources = sorted(
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/frontend.cpp",
"torch/csrc/distributed/c10d/comm.cpp",
"torch/csrc/distributed/c10d/default_comm_hooks.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
"torch/csrc/distributed/c10d/logger.cpp",
"torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
"torch/csrc/distributed/c10d/PrefixStore.cpp",
"torch/csrc/distributed/c10d/ProcessGroup.cpp",
"torch/csrc/distributed/c10d/ProcessGroupGloo.cpp",
"torch/csrc/distributed/c10d/ProcessGroupMPI.cpp",
"torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp",
"torch/csrc/distributed/c10d/quantization/quantization.cpp",
"torch/csrc/distributed/c10d/reducer.cpp",
"torch/csrc/distributed/c10d/sequence_num.cpp",
"torch/csrc/distributed/c10d/Store.cpp",
"torch/csrc/distributed/c10d/TCPStore.cpp",
"torch/csrc/distributed/c10d/Utils.cpp",
"torch/csrc/distributed/c10d/comm.cpp",
"torch/csrc/distributed/c10d/default_comm_hooks.cpp",
"torch/csrc/distributed/c10d/exception.cpp",
"torch/csrc/distributed/c10d/frontend.cpp",
"torch/csrc/distributed/c10d/logger.cpp",
"torch/csrc/distributed/c10d/quantization/quantization.cpp",
"torch/csrc/distributed/c10d/reducer.cpp",
"torch/csrc/distributed/c10d/sequence_num.cpp",
"torch/csrc/distributed/c10d/socket.cpp",
]
# These files are only supported on Linux (and others) but not on Windows.

View File

@ -2,6 +2,7 @@
#include <exception>
#include <string>
#include <system_error>
#include <memory>
#include <queue>
#include <mutex>
@ -15,6 +16,10 @@
#include <c10/util/StringUtil.h>
#include <ATen/detail/FunctionTraits.h>
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#include <torch/csrc/distributed/c10d/exception.h>
#endif
static inline void PyErr_SetString(PyObject* type, const std::string& message) {
PyErr_SetString(type, message.c_str());
}
@ -48,7 +53,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
try {
// Only catch torch-specific exceptions
#define CATCH_TH_ERRORS(retstmnt) \
#define CATCH_CORE_ERRORS(retstmnt) \
catch (python_error & e) { \
e.restore(); \
retstmnt; \
@ -83,12 +88,27 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(msg)); \
retstmnt; \
} \
catch (torch::PyTorchError & e) { \
catch (torch::PyTorchError& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(e.python_type(), msg); \
retstmnt; \
}
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#define CATCH_C10D_ERRORS(retstmnt) \
catch (const c10d::TimeoutError& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_TimeoutError, msg); \
retstmnt; \
}
#else
#define CATCH_C10D_ERRORS(retstmnt)
#endif
#define CATCH_TH_ERRORS(retstmnt) \
CATCH_CORE_ERRORS(retstmnt) \
CATCH_C10D_ERRORS(retstmnt)
#define CATCH_ALL_ERRORS(retstmnt) \
CATCH_TH_ERRORS(retstmnt) \
catch (const std::exception& e) { \

View File

@ -23,66 +23,27 @@
#include <c10d/UnixSockUtils.hpp>
#endif
#include <c10d/socket.h>
namespace c10d {
namespace detail {
namespace {
// Offers RAII for TCP sockets.
class TCPSocket {
public:
TCPSocket() noexcept = default;
/* implicit */ TCPSocket(int handle) noexcept : handle_{handle} {}
TCPSocket(const TCPSocket& other) = delete;
TCPSocket& operator=(const TCPSocket& other) = delete;
TCPSocket(TCPSocket&& other) noexcept : handle_{other.handle_} {
other.handle_ = c10::nullopt;
}
TCPSocket& operator=(TCPSocket&& other) noexcept {
closeSocket();
handle_ = std::exchange(other.handle_, c10::nullopt);
return *this;
}
~TCPSocket() {
closeSocket();
}
int handle() const noexcept {
return handle_.value_or(-1);
}
private:
void closeSocket() noexcept {
if (handle_) {
tcputil::closeSocket(*handle_);
}
}
c10::optional<int> handle_{};
};
// Abstract base class to handle thread state for TCPStoreMasterDaemon and
// TCPStoreWorkerDaemon. Contains the windows/unix implementations to signal a
// shutdown sequence for the thread
class BackgroundThread {
public:
explicit BackgroundThread(TCPSocket&& storeListenSocket);
explicit BackgroundThread(Socket&& storeListenSocket);
virtual ~BackgroundThread() = 0;
protected:
void dispose();
TCPSocket storeListenSocket_;
Socket storeListenSocket_;
std::thread daemonThread_{};
std::vector<TCPSocket> sockets_{};
std::vector<Socket> sockets_{};
#ifdef _WIN32
const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
HANDLE ghStopEvent_{};
@ -102,7 +63,7 @@ class BackgroundThread {
};
// Background thread parent class methods
BackgroundThread::BackgroundThread(TCPSocket&& storeListenSocket)
BackgroundThread::BackgroundThread(Socket&& storeListenSocket)
: storeListenSocket_{std::move(storeListenSocket)} {
// Signal instance destruction to the daemon thread.
initStopSignal();
@ -200,7 +161,7 @@ enum class WatchResponseType : uint8_t {
// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
public:
explicit TCPStoreMasterDaemon(TCPSocket&& storeListenSocket);
explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);
~TCPStoreMasterDaemon() override;
@ -241,7 +202,7 @@ class TCPStoreMasterDaemon : public BackgroundThread {
};
// Simply start the daemon thread
TCPStoreMasterDaemon::TCPStoreMasterDaemon(TCPSocket&& storeListenSocket)
TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
: BackgroundThread{std::move(storeListenSocket)} {
daemonThread_ = std::thread{&TCPStoreMasterDaemon::run, this};
}
@ -270,7 +231,6 @@ void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
// exception, other connections will get an exception once they try to
// use the store. We will go ahead and close this connection whenever
// we hit an exception here.
tcputil::closeSocket(fds[fdIdx].fd);
// Remove all the tracking state of the close FD
for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
@ -562,8 +522,7 @@ void TCPStoreMasterDaemon::run() {
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
TCPSocket socket =
std::get<0>(tcputil::accept(storeListenSocket_.handle()));
Socket socket = storeListenSocket_.accept();
int rawSocket = socket.handle();
sockets_.emplace_back(std::move(socket));
tcputil::addPollfd(fds, rawSocket, POLLIN);
@ -597,8 +556,7 @@ void TCPStoreMasterDaemon::run() {
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
TCPSocket socket =
std::get<0>(tcputil::accept(storeListenSocket_.handle()));
Socket socket = storeListenSocket_.accept();
int rawSocket = socket.handle();
sockets_.emplace_back(std::move(socket));
tcputil::addPollfd(fds, rawSocket, POLLIN);
@ -626,7 +584,7 @@ void TCPStoreMasterDaemon::run() {
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
public:
explicit TCPStoreWorkerDaemon(TCPSocket&& listenSocket);
explicit TCPStoreWorkerDaemon(Socket&& listenSocket);
~TCPStoreWorkerDaemon() override;
// Set the callback to run key change
void setCallback(std::string key, WatchKeyCallback cb);
@ -657,7 +615,7 @@ class TCPStoreWorkerDaemon : public BackgroundThread {
};
// TCPStoreListener class methods
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(TCPSocket&& listenSocket)
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket&& listenSocket)
: BackgroundThread{std::move(listenSocket)} {
daemonThread_ = std::thread{&TCPStoreWorkerDaemon::run, this};
}
@ -806,10 +764,9 @@ std::mutex TCPServer::cache_mutex_{};
std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
auto startCore = [&opts]() {
TCPSocket socket{};
std::uint16_t port{};
Socket socket = Socket::listen(opts.port);
std::tie(socket, port) = tcputil::listen(opts.port);
std::uint16_t port = socket.port();
auto daemon = std::make_unique<TCPStoreMasterDaemon>(std::move(socket));
@ -881,17 +838,19 @@ class TCPClient {
void setTimeout(std::chrono::milliseconds value);
explicit TCPClient(TCPSocket&& socket) : socket_{std::move(socket)} {}
explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}
private:
TCPSocket socket_;
Socket socket_;
};
std::unique_ptr<TCPClient> TCPClient::connect(
const SocketAddress& addr,
const TCPStoreOptions& opts) {
TCPSocket socket =
tcputil::connect(addr.host, addr.port, /* wait */ true, opts.timeout);
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
Socket socket = Socket::connect(addr.host,
addr.port,
SocketOptions{}.connect_timeout(timeout));
return std::make_unique<TCPClient>(std::move(socket));
}
@ -962,8 +921,10 @@ class TCPCallbackClient {
std::unique_ptr<TCPCallbackClient> TCPCallbackClient::connect(
const SocketAddress& addr,
const TCPStoreOptions& opts) {
TCPSocket socket =
tcputil::connect(addr.host, addr.port, /*wait*/ true, opts.timeout);
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
Socket socket = Socket::connect(addr.host,
addr.port,
SocketOptions{}.connect_timeout(timeout));
int rawSocket = socket.handle();
@ -988,6 +949,8 @@ void TCPCallbackClient::setCallback(
} // namespace detail
using detail::Socket;
// TCPStore class methods
TCPStore::TCPStore(
const std::string& masterAddr,
@ -1010,7 +973,7 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
: Store{opts.timeout},
addr_{std::move(host)},
numWorkers_{opts.numWorkers} {
tcputil::socketInitialize();
Socket::initialize();
if (opts.isServer) {
server_ = detail::TCPServer::start(opts);

View File

@ -5,16 +5,8 @@
namespace c10d {
namespace tcputil {
#define AF_SELECTED AF_UNSPEC
#define CONNECT_SOCKET_OFFSET 2
// Closes the given socket descriptor. The return value of `::close()` is
// ignored.
inline void closeSocket(int socket) { ::close(socket); }
// Turns on SO_REUSEADDR for `socket`.
//
// Returns the raw result of `::setsockopt()`; the caller is responsible for
// checking it.
inline int setSocketAddrReUse(int socket) {
  const int enabled = 1;
  const int rc =
      ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(int));
  return rc;
}
// Thin forwarder to the global `::poll()`; arguments and return value are
// passed through unchanged.
inline int poll(struct pollfd *fds, unsigned long nfds, int timeout) {
return ::poll(fds, nfds, timeout);
}
@ -24,70 +16,6 @@ inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
fds.push_back({.fd = socket, .events = events});
}
// Connects `socket` to `nextAddr` using a non-blocking connect followed by a
// poll() for writability, so that the attempt can be bounded by `timeout`.
//
// `startTime` is the moment the overall connection attempt began; the poll
// timeout is the part of `timeout` that has not elapsed since then (clamped
// to zero). When `timeout` equals kNoTimeout, poll() is called with -1.
//
// Throws std::system_error when connect(), poll() or the pending socket error
// reports a failure, and fails a TORCH_CHECK when poll() times out. On success
// O_NONBLOCK is cleared again before returning.
inline void waitSocketConnected(
int socket,
struct ::addrinfo *nextAddr,
std::chrono::milliseconds timeout,
std::chrono::time_point<std::chrono::high_resolution_clock> startTime) {
// NOTE(review): F_SETFL is given O_NONBLOCK alone rather than the current
// flags OR'ed with it -- confirm that no other status flags need preserving.
SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK));
int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
// EINPROGRESS is the expected outcome of a non-blocking connect; anything
// else is a real failure.
if (ret != 0 && errno != EINPROGRESS) {
throw std::system_error(errno, std::system_category());
}
struct ::pollfd pfd;
pfd.fd = socket;
pfd.events = POLLOUT;
int64_t pollTimeout = -1;
if (timeout != kNoTimeout) {
// calculate remaining time and use that as timeout for poll()
const auto elapsed = std::chrono::high_resolution_clock::now() - startTime;
const auto remaining =
std::chrono::duration_cast<std::chrono::milliseconds>(timeout) -
std::chrono::duration_cast<std::chrono::milliseconds>(elapsed);
pollTimeout = std::max(static_cast<int64_t>(0),
static_cast<int64_t>(remaining.count()));
}
// NOTE(review): a negative result is thrown unconditionally, so an EINTR from
// poll() is reported as an error instead of being retried.
int numReady = ::poll(&pfd, 1, pollTimeout);
if (numReady < 0) {
throw std::system_error(errno, std::system_category());
} else if (numReady == 0) {
// errno is cleared before raising the timeout error.
errno = 0;
TORCH_CHECK(
false,
c10::str(
kConnectTimeoutMsg,
" Polled for ",
pollTimeout,
" ms with original timeout of ",
timeout.count(),
" ms."));
}
// The pending socket error is read directly into `errno`.
socklen_t errLen = sizeof(errno);
errno = 0;
::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen);
// `errno` is set when:
// 1. `getsockopt` has failed
// 2. there is awaiting error in the socket
// (the error is saved to the `errno` variable)
if (errno != 0) {
throw std::system_error(errno, std::system_category());
}
// Disable non-blocking mode
int flags;
SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL));
SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
}
// Linux socket does not need init libs first
inline void socketInitialize() {}
inline struct ::pollfd getPollfd(int socket, short events) {
struct ::pollfd res = {.fd = socket, .events = events};
return res;

View File

@ -1,16 +1,5 @@
#ifdef _WIN32
#include <c10d/WinSockUtils.hpp>
#else
#include <arpa/inet.h>
#include <c10d/UnixSockUtils.hpp>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/poll.h>
#include <unistd.h>
#endif
#include <c10d/Utils.hpp>
#include <fcntl.h>
#include <algorithm>
#include <cstring>
#include <memory>
@ -97,298 +86,4 @@ std::vector<at::Tensor> getTensorShapes(
return shapeTensors;
}
namespace tcputil {
namespace {
constexpr int LISTEN_QUEUE_SIZE = 2048;
// Sets the TCP_NODELAY option on `socket`.
//
// A failing `setsockopt()` (return value of -1) is reported through
// SYSCHECK_ERR_RETURN_NEG1.
void setSocketNoDelay(int socket) {
  int enabled = 1;
  socklen_t enabledLen = sizeof(enabled);
  SYSCHECK_ERR_RETURN_NEG1(setsockopt(
      socket,
      IPPROTO_TCP,
      TCP_NODELAY,
      (char*)&enabled,
      enabledLen));
}
// Returns the local port (in host byte order) that the socket `fd` is bound
// to, as reported by `getsockname()`.
//
// Only AF_INET and AF_INET6 sockets are supported; any other address family
// fails a TORCH_CHECK with "unsupported protocol".
PortType getSocketPort(int fd) {
  struct ::sockaddr_storage storage;
  socklen_t storageLen = sizeof(storage);
  SYSCHECK_ERR_RETURN_NEG1(getsockname(
      fd, reinterpret_cast<struct ::sockaddr*>(&storage), &storageLen));

  const bool isV4 = storage.ss_family == AF_INET;
  const bool isV6 = storage.ss_family == AF_INET6;
  TORCH_CHECK(isV4 || isV6, "unsupported protocol");

  if (isV4) {
    struct ::sockaddr_in* v4 =
        reinterpret_cast<struct ::sockaddr_in*>(&storage);
    return ntohs(v4->sin_port);
  }

  struct ::sockaddr_in6* v6 =
      reinterpret_cast<struct ::sockaddr_in6*>(&storage);
  return ntohs(v6->sin6_port);
}
} // namespace
// Converts an IPv4 or IPv6 socket address to its textual form via
// `inet_ntop()`.
//
// Fails a TORCH_CHECK with "unsupported protocol" for any other address
// family. A null result from `inet_ntop()` is reported through SYSCHECK.
std::string sockaddrToString(struct ::sockaddr* addr) {
// Sized for the longer (IPv6) representation plus one extra byte, so the
// explicit terminator writes below stay in bounds.
char address[INET6_ADDRSTRLEN + 1];
if (addr->sa_family == AF_INET) {
struct ::sockaddr_in* s = reinterpret_cast<struct ::sockaddr_in*>(addr);
SYSCHECK(
::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN),
__output != nullptr)
// NOTE(review): the terminator is written at a fixed index, not at the end
// of the converted text -- presumably redundant with inet_ntop()'s own
// termination; confirm.
address[INET_ADDRSTRLEN] = '\0';
} else if (addr->sa_family == AF_INET6) {
struct ::sockaddr_in6* s = reinterpret_cast<struct ::sockaddr_in6*>(addr);
SYSCHECK(
::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN),
__output != nullptr)
address[INET6_ADDRSTRLEN] = '\0';
} else {
TORCH_CHECK(false, "unsupported protocol");
}
return address;
}
// listen, connect and accept
// Creates a TCP socket that listens on `port` and returns it together with
// the port it is actually bound to (as reported by getSocketPort()).
//
// The candidate local addresses come from `getaddrinfo()` and are tried in
// order; the first one on which socket/bind/listen all succeed wins.
//
// Throws std::invalid_argument if no candidate address can be resolved, and
// rethrows the last std::system_error if every candidate fails.
std::pair<int, PortType> listen(PortType port) {
struct ::addrinfo hints, *res = NULL;
std::memset(&hints, 0x00, sizeof(hints));
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux
hints.ai_socktype = SOCK_STREAM; // TCP
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
// by editing `/etc/gai.conf`, so there is no need for manual sorting
// or protocol preference.
int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res);
if (err != 0 || !res) {
throw std::invalid_argument(
"cannot find host to listen on: " + std::string(gai_strerror(err)));
}
// The shared_ptr owns the address list and frees it on every exit path.
std::shared_ptr<struct ::addrinfo> addresses(
res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); });
struct ::addrinfo* nextAddr = addresses.get();
int socket;
while (true) {
try {
SYSCHECK_ERR_RETURN_NEG1(
socket = ::socket(
nextAddr->ai_family,
nextAddr->ai_socktype,
nextAddr->ai_protocol))
SYSCHECK_ERR_RETURN_NEG1(tcputil::setSocketAddrReUse(socket))
SYSCHECK_ERR_RETURN_NEG1(
::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen))
SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE))
break;
} catch (const std::system_error& e) {
// Discard the socket of the failed attempt before trying the next address.
tcputil::closeSocket(socket);
nextAddr = nextAddr->ai_next;
// we have tried all addresses but could not start
// listening on any of them
if (!nextAddr) {
throw;
}
}
}
// get listen port and address
return {socket, getSocketPort(socket)};
}
// Decides how connect() proceeds after a failed attempt on `*nextAddr`.
//
// Must be called from inside a catch handler: the bare `throw;` below
// rethrows the exception that is currently being handled.
//
// Parameters:
//   nextAddr   - in/out cursor into the address list; advanced to the next
//                candidate, or reset to the head of `addresses` for a retry.
//   error_code - error code of the failed attempt.
//   anyRefused - in/out flag, set when an attempt failed with ECONNREFUSED.
//   anyReset   - in/out flag, set when an attempt failed with ECONNRESET.
//   wait       - whether to keep retrying once all addresses have failed.
//   start      - time at which the overall connection attempt began.
//   addresses  - owner of the address list that `nextAddr` points into.
//   timeout    - overall deadline; kNoTimeout disables the deadline check.
//
// Once every address has been tried, the current exception is rethrown
// unless `wait` is set and at least one attempt was refused or reset. In that
// case the function checks the deadline (failing a TORCH_CHECK when it has
// passed), sleeps for one second and restarts from the first address.
void handleConnectException(
    struct ::addrinfo** nextAddr,
    int error_code,
    bool* anyRefused,
    bool* anyReset,
    bool wait,
    std::chrono::time_point<std::chrono::high_resolution_clock> start,
    std::shared_ptr<struct ::addrinfo> addresses,
    std::chrono::milliseconds timeout) {
  // ECONNREFUSED happens if the server is not yet listening.
  if (error_code == ECONNREFUSED) {
    *anyRefused = true;
  }
  // ECONNRESET happens if the server's listen backlog is exhausted.
  if (error_code == ECONNRESET) {
    *anyReset = true;
  }

  // We need to move to the next address because this was not available
  // to connect or to create a socket.
  *nextAddr = (*nextAddr)->ai_next;

  // There are still untried addresses; let the caller try the next one.
  if (*nextAddr) {
    return;
  }

  // We have tried all addresses but could not connect to any of them.
  // The flags must be dereferenced here: negating the (never null) pointers
  // themselves is always false, which made every kind of failure retry.
  if (!wait || (!*anyRefused && !*anyReset)) {
    throw;
  }

  // if a timeout is specified, check time elapsed to see if we need to
  // timeout. A timeout is specified if timeout != kNoTimeout.
  if (timeout != kNoTimeout) {
    const auto elapsed = std::chrono::high_resolution_clock::now() - start;
    TORCH_CHECK(
        elapsed <= timeout,
        c10::str(
            kConnectTimeoutMsg,
            " Original timeout was ",
            timeout.count(),
            " ms."));
  }

  // Back off, then start a fresh round over the full address list.
  std::this_thread::sleep_for(std::chrono::seconds(1));
  *anyRefused = false;
  *anyReset = false;
  *nextAddr = addresses.get();
}
// Adapter for failures that surfaced as std::system_error: extracts the
// numeric error code from `e` and forwards every other argument unchanged to
// handleConnectException().
void handleConnectSystemError(
struct ::addrinfo** nextAddr,
std::system_error& e,
bool* anyRefused,
bool* anyReset,
bool wait,
std::chrono::time_point<std::chrono::high_resolution_clock> start,
std::shared_ptr<struct ::addrinfo> addresses,
std::chrono::milliseconds timeout) {
handleConnectException(
nextAddr,
e.code().value(),
anyRefused,
anyReset,
wait,
start,
addresses,
timeout);
}
// Opens a TCP connection to `address`:`port` and returns the connected
// socket with TCP_NODELAY set (via setSocketNoDelay()).
//
// The candidate addresses come from `getaddrinfo()` and are tried in order.
// What happens after a failed attempt (next address, retry round, timeout or
// rethrow) is decided by handleConnectException(), which receives `wait` and
// `timeout`.
//
// Throws std::invalid_argument if the host cannot be resolved; other failures
// propagate from the connect-error handlers.
int connect(
const std::string& address,
PortType port,
bool wait,
const std::chrono::milliseconds& timeout) {
struct ::addrinfo hints, *res = NULL;
std::memset(&hints, 0x00, sizeof(hints));
hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux
hints.ai_socktype = SOCK_STREAM; // TCP
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
// by editing `/etc/gai.conf`, so there is no need for manual sorting
// or protocol preference.
int err =
::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res);
if (err != 0 || !res) {
throw std::invalid_argument(
"host not found: " + std::string(gai_strerror(err)));
}
// The shared_ptr owns the address list and frees it on every exit path.
std::shared_ptr<struct ::addrinfo> addresses(
res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); });
struct ::addrinfo* nextAddr = addresses.get();
int socket;
// Loop over the addresses if at least one of them gave us ECONNREFUSED
// or ECONNRESET. This may happen if the server hasn't started listening
// yet, or is listening but has its listen backlog exhausted.
bool anyRefused = false;
bool anyReset = false;
const auto start = std::chrono::high_resolution_clock::now();
while (true) {
try {
SYSCHECK_ERR_RETURN_NEG1(
socket = ::socket(
nextAddr->ai_family,
nextAddr->ai_socktype,
nextAddr->ai_protocol))
// Closes the socket if the connection attempt below throws.
ResourceGuard socketGuard([socket]() { tcputil::closeSocket(socket); });
// We need to connect in non-blocking mode, so we can use a timeout
waitSocketConnected(socket, nextAddr, timeout, start);
// Connected: keep the socket open and leave the retry loop.
socketGuard.release();
break;
} catch (std::system_error& e) {
handleConnectSystemError(
&nextAddr,
e,
&anyRefused,
&anyReset,
wait,
start,
addresses,
timeout);
} catch (std::exception& e) {
// Exceptions that carry no error code are classified by the current value
// of `errno`.
handleConnectException(
&nextAddr,
errno,
&anyRefused,
&anyReset,
wait,
start,
addresses,
timeout);
}
}
setSocketNoDelay(socket);
return socket;
}
// Waits for an incoming connection on `listenSocket`, accepts it and returns
// the new socket (with TCP_NODELAY set) together with the peer's address in
// textual form.
//
// `timeout.count()` is handed to poll() unchanged as its millisecond timeout.
// Fails a TORCH_CHECK when poll() reports that the wait timed out, and throws
// std::system_error on poll() errors other than EINTR or when the listening
// socket becomes ready without POLLIN.
std::tuple<int, std::string> accept(
int listenSocket,
const std::chrono::milliseconds& timeout) {
// poll on listen socket, it allows to make timeout
std::unique_ptr<struct ::pollfd[]> events(new struct ::pollfd[1]);
events[0] = tcputil::getPollfd(listenSocket, POLLIN);
while (true) {
int res = tcputil::poll(events.get(), 1, timeout.count());
if (res == 0) {
TORCH_CHECK(
false,
"waiting for processes to "
"connect has timed out");
} else if (res == -1) {
// Interrupted by a signal: poll again (with the full timeout).
if (errno == EINTR) {
continue;
}
throw std::system_error(errno, std::system_category());
} else {
if (!(events[0].revents & POLLIN))
throw std::system_error(ECONNABORTED, std::system_category());
break;
}
}
int socket;
SYSCHECK_ERR_RETURN_NEG1(socket = ::accept(listenSocket, NULL, NULL))
// NOTE(review): `socket` is a raw descriptor with no guard, so it is not
// closed if one of the checked calls below throws.
// Get address of the connecting process
struct ::sockaddr_storage addr;
socklen_t addrLen = sizeof(addr);
SYSCHECK_ERR_RETURN_NEG1(::getpeername(
socket, reinterpret_cast<struct ::sockaddr*>(&addr), &addrLen))
setSocketNoDelay(socket);
return std::make_tuple(
socket, sockaddrToString(reinterpret_cast<struct ::sockaddr*>(&addr)));
}
} // namespace tcputil
} // namespace c10d

View File

@ -482,7 +482,6 @@ size_t computeLengthsAndOffsets(
}
using RankType = uint32_t;
using PortType = uint16_t;
using SizeType = uint64_t;
// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets
@ -536,32 +535,8 @@ using SizeType = uint64_t;
// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
// Helper resource guard class.
//
// Runs the supplied callable when the guard is destroyed, unless release()
// has been called first. Typical use is to undo an acquisition on an error
// path and to call release() once ownership has been handed over.
//
// NOTE: the guard is copyable; a copy runs the callable a second time, so
// guards should not be copied.
class ResourceGuard {
 public:
  // `destructor` is the cleanup action to run at scope exit. An empty
  // std::function is accepted and results in a guard that does nothing.
  ResourceGuard(std::function<void()> destructor)
      : destructor_(std::move(destructor)), released_(false) {}

  ~ResourceGuard() {
    // Invoking an empty std::function throws std::bad_function_call, which
    // would terminate the program when raised from a destructor.
    if (!released_ && destructor_) {
      destructor_();
    }
  }

  // Disarms the guard: the cleanup action will not run on destruction.
  void release() {
    released_ = true;
  }

 private:
  std::function<void()> destructor_;
  bool released_;
};
namespace tcputil {
constexpr std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds(-1);
const std::string kConnectTimeoutMsg = "connect() timed out.";
// Send and receive
template <typename T>
void sendBytes(
@ -677,20 +652,5 @@ inline std::string recvString(int socket) {
return std::string(value.data(), value.size());
}
// Other helpers
std::string sockaddrToString(struct sockaddr* addr);
std::pair<int, PortType> listen(PortType port);
int connect(
const std::string& address,
PortType port,
bool wait = true,
const std::chrono::milliseconds& timeout = kNoTimeout);
std::tuple<int, std::string> accept(
int listenSocket,
const std::chrono::milliseconds& timeout = kNoTimeout);
} // namespace tcputil
} // namespace c10d

View File

@ -5,17 +5,8 @@
namespace c10d {
namespace tcputil {
#define AF_SELECTED AF_INET
#define CONNECT_SOCKET_OFFSET 1
// Closes the given socket through Winsock's `closesocket()`. The return value
// is ignored.
inline void closeSocket(int socket) { ::closesocket(socket); }
// Sets the SO_REUSEADDR option on `socket` and returns the raw `setsockopt()`
// result.
//
// NOTE(review): the option value passed here is `false`, i.e. address reuse
// is switched off, and its size is sizeof(bool). Presumably deliberate on
// Windows -- confirm against the intended bind semantics.
inline int setSocketAddrReUse(int socket) {
bool optval = false;
return ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char *)&optval,
sizeof(bool));
}
// Forwards to `WSAPoll()`, giving Windows code the same `poll()` entry point
// name as the Unix header; arguments and result are passed through unchanged.
inline int poll(struct pollfd *fdArray, unsigned long fds, int timeout) {
return WSAPoll(fdArray, fds, timeout);
}
@ -25,62 +16,6 @@ inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
fds.push_back({(SOCKET)socket, events});
}
// Connects `socket` to `nextAddr` by switching it to non-blocking mode and
// calling connect() repeatedly until it succeeds, fails, or `timeout` (counted
// from `startTime`) has elapsed.
//
// Fails a TORCH_CHECK on timeout and throws std::system_error for any Winsock
// error other than the "still in progress" / "already connected" codes handled
// below. Blocking mode is restored before a successful return.
inline void waitSocketConnected(
int socket,
struct ::addrinfo *nextAddr,
std::chrono::milliseconds timeout,
std::chrono::time_point<std::chrono::high_resolution_clock> startTime) {
// A non-zero FIONBIO argument enables non-blocking mode.
unsigned long block_mode = 1;
SYSCHECK_ERR_RETURN_NEG1(ioctlsocket(socket, FIONBIO, &block_mode));
int ret;
do {
ret = connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
if (ret == SOCKET_ERROR) {
int err = WSAGetLastError();
if (err == WSAEISCONN) {
// A previous iteration's attempt has completed: the socket is connected.
break;
} else if (err == WSAEALREADY || err == WSAEWOULDBLOCK) {
// The connection is still being established; check the deadline, then
// wait briefly and ask again.
if (timeout != kNoTimeout) {
const auto elapsed =
std::chrono::high_resolution_clock::now() - startTime;
if (elapsed > timeout) {
errno = 0;
TORCH_CHECK(
false,
c10::str(
kConnectTimeoutMsg,
" Original timeout was ",
timeout.count(),
" ms."));
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
throw std::system_error(err, std::system_category(),
"Socket connect failed");
}
} while (ret == SOCKET_ERROR);
// Switch the socket back to blocking mode.
block_mode = 0;
SYSCHECK_ERR_RETURN_NEG1(ioctlsocket(socket, FIONBIO, &block_mode));
}
// All processes (applications or DLLs) that call Winsock
// functions must initialize the use of the Windows Sockets
// DLL before making other Winsock function calls.
// This also makes certain that Winsock is supported on the system.
// Ref to
// https://docs.microsoft.com/en-us/windows/win32/winsock/initializing-winsock
//
// Initializes Winsock 2.2 once per process. Throws std::system_error if
// `WSAStartup()` fails; in that case the once-flag stays unset, so a later
// call attempts the initialization again.
inline void socketInitialize() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    WSADATA wsa_data;
    // WSAStartup() signals failure by returning a non-zero error code
    // directly (it does not return -1), so the result has to be compared
    // against zero instead of going through the -1 based SYSCHECK macro.
    const int result = WSAStartup(MAKEWORD(2, 2), &wsa_data);
    if (result != 0) {
      throw std::system_error(
          result, std::system_category(), "WSAStartup failed");
    }
  });
}
inline struct ::pollfd getPollfd(int socket, short events) {
struct ::pollfd res = {(SOCKET)socket, events};
return res;

View File

@ -0,0 +1,41 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <system_error>
#include <fmt/format.h>
namespace fmt {
template <>
struct formatter<std::error_code> {
constexpr decltype(auto) parse(format_parse_context& ctx) {
return ctx.begin();
}
template <typename FormatContext>
decltype(auto) format(const std::error_code& err, FormatContext& ctx) {
return format_to(ctx.out(),
"({} error: {} - {})",
err.category().name(),
err.value(),
err.message());
}
};
} // namespace fmt
namespace c10d {
namespace detail {

// Captures the calling thread's current `errno` value as a std::error_code in
// the generic category.
inline std::error_code lastError() noexcept {
  const int code = errno;
  return std::error_code{code, std::generic_category()};
}

} // namespace detail
} // namespace c10d

View File

@ -0,0 +1,9 @@
#include <c10d/exception.h>
namespace c10d {
// Out-of-line definitions of the destructors that the exception classes only
// declare in exception.h. Both use the compiler-generated behavior.
C10dError::~C10dError() = default;
TimeoutError::~TimeoutError() = default;
} // namespace c10d

View File

@ -0,0 +1,45 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <stdexcept>
#include <c10/macros/Macros.h>
namespace c10d {
// Base class of the exceptions declared in this header. It is a
// std::runtime_error and inherits all of its constructors.
class TORCH_API C10dError : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
// Copy and move operations are explicitly defaulted.
C10dError(const C10dError&) = default;
C10dError& operator=(const C10dError&) = default;
C10dError(C10dError&&) = default;
C10dError& operator=(C10dError&&) = default;
// Declared here only; the definition lives in a source file.
~C10dError() override;
};
// A C10dError subtype that inherits C10dError's constructors, giving callers
// a distinct type to catch. Presumably raised when an operation exceeds its
// time limit -- the raising code is not part of this header.
class TORCH_API TimeoutError : public C10dError {
public:
using C10dError::C10dError;
// Copy and move operations are explicitly defaulted.
TimeoutError(const TimeoutError&) = default;
TimeoutError& operator=(const TimeoutError&) = default;
TimeoutError(TimeoutError&&) = default;
TimeoutError& operator=(TimeoutError&&) = default;
// Declared here only; the definition lives in a source file.
~TimeoutError() override;
};
} // namespace c10d

View File

@ -911,7 +911,7 @@ Example::
)")
.def(
py::init([](const std::string& host,
::c10d::PortType port,
uint16_t port,
int worldSize,
bool isServer,
std::chrono::milliseconds timeout,

View File

@ -0,0 +1,20 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <fmt/format.h>
#include <c10/util/Logging.h>
// Logging helpers for c10d. Each macro takes {fmt}-style arguments (a format
// string followed by its values), formats them with fmt::format() and streams
// the result to the corresponding LOG severity. The message is only emitted
// when FLAGS_caffe2_log_level is at or below the threshold shown in the macro.

// Logs at ERROR severity when FLAGS_caffe2_log_level <= 2.
#define C10D_ERROR(...)\
LOG_IF(ERROR, FLAGS_caffe2_log_level <= 2) << fmt::format(__VA_ARGS__)
// Logs at WARNING severity when FLAGS_caffe2_log_level <= 1.
#define C10D_WARNING(...)\
LOG_IF(WARNING, FLAGS_caffe2_log_level <= 1) << fmt::format(__VA_ARGS__)
// Logs at INFO severity when FLAGS_caffe2_log_level <= 0.
#define C10D_INFO(...)\
LOG_IF(INFO, FLAGS_caffe2_log_level <= 0) << fmt::format(__VA_ARGS__)

View File

@ -0,0 +1,847 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <c10d/socket.h>
#include <cstring>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>
#ifdef _WIN32
#include <mutex>
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <fmt/chrono.h>
#include <fmt/format.h>
#include <c10d/error.h>
#include <c10d/exception.h>
#include <c10d/logging.h>
namespace c10d {
namespace detail {
namespace {
#ifdef _WIN32

// Winsock names its polling function `WSAPoll`; expose it as `pollFd` so that
// the rest of the file does not need platform checks.
const auto pollFd = ::WSAPoll;

// Winsock expects socket option buffers as `char*` rather than `void*`. The two
// wrappers below perform that cast in a single place.
int getSocketOption(SOCKET sock, int level, int name, void* value, int* length) {
  auto* raw = static_cast<char*>(value);
  return ::getsockopt(sock, level, name, raw, length);
}

int setSocketOption(SOCKET sock, int level, int name, const void* value, int length) {
  auto* raw = static_cast<const char*>(value);
  return ::setsockopt(sock, level, name, raw, length);
}

// Winsock reports its own error codes. They are wrapped in the system category
// so that they can be compared against `std::errc` values.
inline std::error_code getSocketError() noexcept {
  const int code = ::WSAGetLastError();
  return std::error_code{code, std::system_category()};
}

// Stores `val` as the calling thread's last Winsock error.
inline void setSocketError(int val) noexcept {
  ::WSASetLastError(val);
}

#else

// On other platforms the Berkeley socket functions are used directly.
const auto pollFd = ::poll;

const auto getSocketOption = ::getsockopt;
const auto setSocketOption = ::setsockopt;

// Returns the last error as reported by `lastError()`.
inline std::error_code getSocketError() noexcept {
  return lastError();
}

// Stores `val` in `errno`.
inline void setSocketError(int val) noexcept {
  errno = val;
}

#endif
// Blocks the calling thread for the duration `d`.
//
// On non-Windows platforms an interrupted sleep (EINTR) is reported by throwing
// `std::system_error`; any other failure of the sleep is ignored.
void delay(std::chrono::seconds d) {
#ifndef _WIN32
  // The C++ Standard does not specify whether `sleep_for()` should be signal-
  // aware; therefore, we use the `nanosleep()` syscall.
  ::timespec interval{};
  interval.tv_sec = d.count();

  if (::nanosleep(&interval, nullptr) == 0) {
    return;
  }

  // We don't care about error conditions other than EINTR since a failure
  // here is not critical.
  std::error_code err = getSocketError();
  if (err == std::errc::interrupted) {
    throw std::system_error{err};
  }
#else
  std::this_thread::sleep_for(d);
#endif
}
class SocketListenOp;
class SocketConnectOp;
} // namespace
// Owns a native socket handle and closes it on destruction. Instances can be
// neither copied nor moved; they are passed around through `std::unique_ptr`.
class SocketImpl {
  friend class SocketListenOp;
  friend class SocketConnectOp;

 public:
  // Native handle type and the sentinel value that denotes "no socket".
#ifdef _WIN32
  using Handle = SOCKET;

  static constexpr Handle invalid_socket = INVALID_SOCKET;
#else
  using Handle = int;

  static constexpr Handle invalid_socket = -1;
#endif

  // Takes ownership of `hnd`.
  explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {}

  SocketImpl(const SocketImpl& other) = delete;
  SocketImpl& operator=(const SocketImpl& other) = delete;

  SocketImpl(SocketImpl&& other) noexcept = delete;
  SocketImpl& operator=(SocketImpl&& other) noexcept = delete;

  ~SocketImpl();

  // Accepts a pending connection and returns it as a new `SocketImpl`.
  std::unique_ptr<SocketImpl> accept() const;

  void closeOnExec() noexcept;

  // Both functions throw `SocketError` if the mode cannot be changed.
  void enableNonBlocking();
  void disableNonBlocking();

  // The option setters below return whether the option could be applied.
  bool enableNoDelay() noexcept;
  bool enableDualStack() noexcept;

#ifdef _WIN32
  bool enableExclusiveAddressUse() noexcept;
#else
  bool enableAddressReuse() noexcept;
#endif

  // Returns the local port number the socket is bound to.
  std::uint16_t getPort() const;

  Handle handle() const noexcept {
    return hnd_;
  }

 private:
  // Sets a boolean socket option; returns true on success.
  bool setSocketFlag(int level, int optname, bool value) noexcept;

  Handle hnd_;
};
// Releases the owned handle with the platform's close function. The result of
// the close call is not inspected.
SocketImpl::~SocketImpl() {
#ifndef _WIN32
  ::close(hnd_);
#else
  ::closesocket(hnd_);
#endif
}
// Accepts a pending connection on this (listening) socket.
//
// Returns the accepted connection wrapped in a new `SocketImpl` with the
// close-on-exec flag set and, if possible, TCP no-delay enabled.
//
// Throws `std::system_error` if the call was interrupted by a signal and
// `SocketError` for any other failure.
std::unique_ptr<SocketImpl> SocketImpl::accept() const {
  ::sockaddr_storage addr_s{};

  auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);

  ::socklen_t addr_len = sizeof(addr_s);

  Handle hnd = ::accept(hnd_, addr_ptr, &addr_len);
  if (hnd == invalid_socket) {
    std::error_code err = getSocketError();
    if (err == std::errc::interrupted) {
      throw std::system_error{err};
    }

    std::string msg{};
    if (err == std::errc::invalid_argument) {
      msg = fmt::format("The server socket on {} is not listening for connections.", *this);
    } else {
      msg = fmt::format("The server socket on {} has failed to accept a connection {}.", *this, err);
    }

    // `msg` is already formatted; pass it as an argument so that any brace in
    // it is not interpreted as a replacement field.
    C10D_ERROR("{}", msg);

    throw SocketError{msg};
  }

  // Take ownership of the new handle before doing anything else that may
  // throw (e.g. formatting the log message below); otherwise the accepted
  // connection would never be closed.
  auto impl = std::make_unique<SocketImpl>(hnd);

  // Wrap the peer address so that it can be formatted for logging.
  ::addrinfo addr{};
  addr.ai_addr = addr_ptr;
  addr.ai_addrlen = addr_len;

  C10D_INFO("The server socket on {} has accepted a connection from {}.", *this, addr);

  // Make sure that we do not "leak" our file descriptors to child processes.
  impl->closeOnExec();

  if (!impl->enableNoDelay()) {
    C10D_WARNING("The no-delay option cannot be enabled for the client socket on {}.", addr);
  }

  return impl;
}
// Sets the FD_CLOEXEC descriptor flag on the socket. The result of the call is
// not checked, and nothing is done on Windows.
void SocketImpl::closeOnExec() noexcept {
#ifdef _WIN32
  // Nothing to do.
#else
  ::fcntl(hnd_, F_SETFD, FD_CLOEXEC);
#endif
}
// Switches the socket to non-blocking mode.
//
// Throws `SocketError` if the mode cannot be changed.
void SocketImpl::enableNonBlocking() {
#ifdef _WIN32
  unsigned long mode = 1;

  const bool switched = ::ioctlsocket(hnd_, FIONBIO, &mode) == 0;
#else
  // Read the current status flags first so that only O_NONBLOCK is added.
  const int flags = ::fcntl(hnd_, F_GETFL);

  const bool switched = flags != -1 && ::fcntl(hnd_, F_SETFL, flags | O_NONBLOCK) == 0;
#endif

  if (!switched) {
    throw SocketError{"The socket cannot be switched to non-blocking mode."};
  }
}
// Switches the socket back to blocking mode.
//
// Throws `SocketError` if the mode cannot be changed.
//
// TODO: Remove once we migrate everything to non-blocking mode.
void SocketImpl::disableNonBlocking() {
#ifdef _WIN32
  unsigned long mode = 0;

  const bool switched = ::ioctlsocket(hnd_, FIONBIO, &mode) == 0;
#else
  // Read the current status flags first so that only O_NONBLOCK is cleared.
  const int flags = ::fcntl(hnd_, F_GETFL);

  const bool switched = flags != -1 && ::fcntl(hnd_, F_SETFL, flags & ~O_NONBLOCK) == 0;
#endif

  if (!switched) {
    throw SocketError{"The socket cannot be switched to blocking mode."};
  }
}
// Enables TCP_NODELAY; returns whether the option could be set.
bool SocketImpl::enableNoDelay() noexcept {
  const bool applied = setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true);
  return applied;
}

// Clears IPV6_V6ONLY; returns whether the option could be set.
bool SocketImpl::enableDualStack() noexcept {
  const bool applied = setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false);
  return applied;
}

#ifdef _WIN32

// Enables SO_EXCLUSIVEADDRUSE; returns whether the option could be set.
bool SocketImpl::enableExclusiveAddressUse() noexcept {
  const bool applied = setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true);
  return applied;
}

#else

// Enables SO_REUSEADDR; returns whether the option could be set.
bool SocketImpl::enableAddressReuse() noexcept {
  const bool applied = setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true);
  return applied;
}

#endif
// Returns the local port of the socket in host byte order.
//
// Throws `SocketError` if the local address cannot be queried.
std::uint16_t SocketImpl::getPort() const {
  ::sockaddr_storage storage{};

  ::socklen_t storage_len = sizeof(storage);

  auto* generic = reinterpret_cast<::sockaddr*>(&storage);

  if (::getsockname(hnd_, generic, &storage_len) != 0) {
    throw SocketError{"The port number of the socket cannot be retrieved."};
  }

  // Any family other than AF_INET is read through the IPv6 address layout.
  if (storage.ss_family != AF_INET) {
    return ntohs(reinterpret_cast<::sockaddr_in6*>(&storage)->sin6_port);
  }

  return ntohs(reinterpret_cast<::sockaddr_in*>(&storage)->sin_port);
}
// Sets the boolean socket option `optname` at `level` to `value`.
//
// Returns true if the underlying option call reported success.
bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept {
  // The option value is passed to the system as an integer-typed flag.
#ifdef _WIN32
  auto flag = value ? TRUE : FALSE;
#else
  auto flag = value ? 1 : 0;
#endif

  const int rc = setSocketOption(hnd_, level, optname, &flag, sizeof(flag));

  return rc == 0;
}
namespace {
// Deleter that releases an address list with `freeaddrinfo()`.
struct addrinfo_delete {
  void operator()(::addrinfo* list) const noexcept {
    ::freeaddrinfo(list);
  }
};

// Owning pointer to an address list.
using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>;
// Creates a listening server socket on a given port. Failures of individual
// attempts are collected in `errors_` so that they can be reported together.
class SocketListenOp {
 public:
  // `opts` is stored by pointer and must outlive this object.
  SocketListenOp(std::uint16_t port, const SocketOptions& opts);

  // Returns the listening socket or throws `SocketError` if no local address
  // could be used.
  std::unique_ptr<SocketImpl> run();

 private:
  // Tries every local address of the specified address family.
  bool tryListen(int family);

  // Tries a single local address; on success the socket is left in `socket_`.
  bool tryListen(const ::addrinfo& addr);

  // Formats an error message, logs it as a warning, and remembers it.
  template <typename... Args>
  void recordError(fmt::string_view format, Args&&... args) {
    auto msg = fmt::format(format, std::forward<Args>(args)...);

    // `msg` is already formatted. It must be passed as an argument instead of
    // as the format string; otherwise a '{' or '}' in the formatted text would
    // make the logging macro throw `fmt::format_error`.
    C10D_WARNING("{}", msg);

    errors_.emplace_back(std::move(msg));
  }

  std::string port_;
  const SocketOptions* opts_;
  std::vector<std::string> errors_{};
  std::unique_ptr<SocketImpl> socket_{};
};
// Stores the port as text and keeps a pointer to `opts`.
SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions& opts)
    : port_(fmt::to_string(port)), opts_(&opts) {}
// Tries to start listening and returns the resulting socket.
//
// When IPv6 is preferred, IPv6 addresses are tried first and IPv4 addresses
// second; otherwise addresses of any family are tried in a single pass. Throws
// `SocketError` carrying all recorded errors if every attempt fails.
std::unique_ptr<SocketImpl> SocketListenOp::run() {
  const bool ipv6_first = opts_->prefer_ipv6();

  if (!ipv6_first) {
    C10D_INFO("The server socket will attempt to listen on an IPv4 or IPv6 address.");

    if (tryListen(AF_UNSPEC)) {
      return std::move(socket_);
    }
  } else {
    C10D_INFO("The server socket will attempt to listen on an IPv6 address.");

    if (tryListen(AF_INET6)) {
      return std::move(socket_);
    }

    C10D_INFO("The server socket will attempt to listen on an IPv4 address.");

    if (tryListen(AF_INET)) {
      return std::move(socket_);
    }
  }

  constexpr auto* msg = "The server socket has failed to listen on any local network address.";

  C10D_ERROR(msg);

  throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
}
// Resolves the local addresses of the given family and tries to listen on each
// of them in order. Returns true as soon as one attempt succeeds.
bool SocketListenOp::tryListen(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  ::addrinfo* head = nullptr;

  int rc = ::getaddrinfo(nullptr, port_.c_str(), &hints, &head);
  if (rc != 0) {
    const char* gai_err = ::gai_strerror(rc);

    // Prefix used in the message to name the address family, if specific.
    const char* family_label = "";
    if (family == AF_INET) {
      family_label = "IPv4 ";
    } else if (family == AF_INET6) {
      family_label = "IPv6 ";
    }

    recordError("The local {}network addresses cannot be retrieved (gai error: {} - {}).",
                family_label,
                rc,
                gai_err);

    return false;
  }

  // Frees the address list when the function returns.
  addrinfo_ptr owner{head};

  ::addrinfo* candidate = head;
  while (candidate != nullptr) {
    C10D_INFO("The server socket is attempting to listen on {}.", *candidate);

    if (tryListen(*candidate)) {
      return true;
    }

    candidate = candidate->ai_next;
  }

  return false;
}
bool SocketListenOp::tryListen(const ::addrinfo& addr) {
SocketImpl::Handle hnd = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
if (hnd == SocketImpl::invalid_socket) {
recordError("The server socket cannot be initialized on {} {}.", addr, getSocketError());
return false;
}
socket_ = std::make_unique<SocketImpl>(hnd);
#ifndef _WIN32
if (!socket_->enableAddressReuse()) {
C10D_WARNING("The address reuse option cannot be enabled for the server socket on {}.", addr);
}
#endif
#ifdef _WIN32
// The SO_REUSEADDR flag has a significantly different behavior on Windows
// compared to Unix-like systems. It allows two or more processes to share
// the same port simultaneously, which is totally unsafe.
//
// Here we follow the recommendation of Microsoft and use the non-standard
// SO_EXCLUSIVEADDRUSE flag instead.
if (!socket_->enableExclusiveAddressUse()) {
C10D_WARNING("The exclusive address use option cannot be enabled for the server socket on {}.",
addr);
}
#endif
// Not all operating systems support dual-stack sockets by default. Since we
// wish to use our IPv6 socket for IPv4 communication as well, we explicitly
// ask the system to enable it.
if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) {
C10D_WARNING("The server socket does not support IPv4 communication on {}.", addr);
}
if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) {
recordError("The server socket has failed to bind to {} {}.", addr, getSocketError());
return false;
}
// NOLINTNEXTLINE(bugprone-argument-comment)
if (::listen(socket_->handle(), /*backlog=*/2048) != 0) {
recordError("The server socket has failed to listen on {} {}.", addr, getSocketError());
return false;
}
socket_->closeOnExec();
C10D_INFO("The server socket has started to listen on {}.", addr);
return true;
}
// Connects a client socket to a remote host and port, retrying until the
// configured connect timeout expires. Failures of individual attempts are
// collected in `errors_` so that they can be reported together.
class SocketConnectOp {
  using Clock = std::chrono::steady_clock;
  using Duration = std::chrono::steady_clock::duration;
  using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;

  // Pause between two connection attempts to the same address.
  static const std::chrono::seconds delay_duration_;

  enum class ConnectResult {
    Success,
    Error,
    Retry,
    TimeOut
  };

 public:
  // `host` and `opts` are not copied; both must outlive this object.
  SocketConnectOp(const std::string& host, std::uint16_t port, const SocketOptions& opts);

  // Returns the connected socket, throws `TimeoutError` if the deadline
  // expires, or throws `SocketError` if no address could be connected to.
  std::unique_ptr<SocketImpl> run();

 private:
  // Tries every address of `host_` that belongs to the given address family.
  bool tryConnect(int family);

  // Tries a single address; on success the socket is left in `socket_`.
  ConnectResult tryConnect(const ::addrinfo& addr);

  // Performs the non-blocking connect and waits for its completion.
  ConnectResult tryConnectCore(const ::addrinfo& addr);

  // Formats an error message, logs it as a warning, and remembers it.
  template <typename... Args>
  void recordError(fmt::string_view format, Args&&... args) {
    auto msg = fmt::format(format, std::forward<Args>(args)...);

    // `msg` is already formatted and may contain the caller-supplied host
    // name. It must be passed as an argument instead of as the format string;
    // otherwise a '{' or '}' in it would make the logging macro throw
    // `fmt::format_error`.
    C10D_WARNING("{}", msg);

    errors_.emplace_back(std::move(msg));
  }

  // Points into the string passed to the constructor.
  const char* host_;
  std::string port_;
  const SocketOptions* opts_;
  TimePoint deadline_{};
  std::vector<std::string> errors_{};
  std::unique_ptr<SocketImpl> socket_{};
};
// One second pause between connection attempts.
const std::chrono::seconds SocketConnectOp::delay_duration_{1};

// Keeps a pointer to the character data of `host` (no copy is made), stores
// the port as text, and keeps a pointer to `opts`.
SocketConnectOp::SocketConnectOp(const std::string& host,
                                 std::uint16_t port,
                                 const SocketOptions& opts)
    : host_(host.c_str()), port_(fmt::to_string(port)), opts_(&opts) {}
// Tries to connect to the remote host and returns the resulting socket.
//
// When IPv6 is preferred, IPv6 addresses are tried first and IPv4 addresses
// second; otherwise addresses of any family are tried in a single pass. Throws
// `SocketError` carrying all recorded errors if every attempt fails.
std::unique_ptr<SocketImpl> SocketConnectOp::run() {
  if (opts_->prefer_ipv6()) {
    C10D_INFO("The client socket will attempt to connect to an IPv6 address of ({}, {}).",
              host_,
              port_);

    if (tryConnect(AF_INET6)) {
      return std::move(socket_);
    }

    C10D_INFO("The client socket will attempt to connect to an IPv4 address of ({}, {}).",
              host_,
              port_);

    if (tryConnect(AF_INET)) {
      return std::move(socket_);
    }
  } else {
    C10D_INFO("The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).",
              host_,
              port_);

    if (tryConnect(AF_UNSPEC)) {
      return std::move(socket_);
    }
  }

  auto msg = fmt::format(
      "The client socket has failed to connect to any network address of ({}, {}).",
      host_,
      port_);

  // `msg` embeds the caller-supplied host name; log it as an argument so that
  // a '{' or '}' in it is not parsed as a replacement field (which would throw
  // `fmt::format_error` instead of the intended `SocketError`).
  C10D_ERROR("{}", msg);

  throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
}
// Resolves the addresses of the remote host for the given family and tries to
// connect to each of them, repeating the whole pass while at least one address
// asks for a retry.
//
// Returns true once a connection is established and false if the addresses
// cannot be resolved or all of them fail without asking for a retry. Throws
// `TimeoutError` if the connect deadline expires.
bool SocketConnectOp::tryConnect(int family) {
  ::addrinfo hints{}, *naked_result = nullptr;

  hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result);
  if (r != 0) {
    const char* gai_err = ::gai_strerror(r);

    recordError("The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
                family == AF_INET ? "IPv4 " : family == AF_INET6 ? "IPv6 " : "",
                host_,
                port_,
                r,
                gai_err);

    return false;
  }

  // Frees the address list when the function returns.
  addrinfo_ptr result{naked_result};

  // The deadline is shared by all attempts made for this address family.
  deadline_ = Clock::now() + opts_->connect_timeout();

  bool retry; // NOLINT(cppcoreguidelines-init-variables)
  do {
    retry = false;

    // Only the errors of the most recent pass are kept.
    errors_.clear();

    for (::addrinfo* addr = naked_result; addr != nullptr; addr = addr->ai_next) {
      C10D_INFO("The client socket is attempting to connect to {}.", *addr);

      ConnectResult cr = tryConnect(*addr);
      if (cr == ConnectResult::Success) {
        return true;
      }

      if (cr == ConnectResult::TimeOut) {
        auto msg = fmt::format(
            "The client socket has timed out after {} while trying to connect to ({}, {}).",
            opts_->connect_timeout(),
            host_,
            port_);

        // `msg` embeds the caller-supplied host name; log it as an argument so
        // that a '{' or '}' in it is not parsed as a replacement field (which
        // would throw `fmt::format_error` instead of `TimeoutError`).
        C10D_ERROR("{}", msg);

        throw TimeoutError{msg};
      }

      if (cr == ConnectResult::Retry) {
        retry = true;
      }
    }
  } while (retry);

  return false;
}
// Makes a single connection attempt to `addr`.
//
// Returns Success if the connection is established (the socket is left in
// `socket_`), Retry if the attempt was refused or reset and there is still time
// left for another try, TimeOut if the deadline has been reached, and Error for
// any other failure. Throws `std::system_error` if the attempt was interrupted
// by a signal.
SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(const ::addrinfo& addr) {
// Do not even create a socket if the deadline has already passed.
if (Clock::now() >= deadline_) {
return ConnectResult::TimeOut;
}
SocketImpl::Handle hnd = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
if (hnd == SocketImpl::invalid_socket) {
recordError("The client socket cannot be initialized to connect to {} {}.",
addr,
getSocketError());
return ConnectResult::Error;
}
// Replaces (and thereby closes) the socket of any previous attempt.
socket_ = std::make_unique<SocketImpl>(hnd);
// The connect is performed in non-blocking mode so that it can be bounded by
// the deadline.
socket_->enableNonBlocking();
ConnectResult cr = tryConnectCore(addr);
if (cr == ConnectResult::Error) {
// Must be read right after tryConnectCore() returns: it holds either the
// error of the call that failed there or the SO_ERROR value that
// tryConnectCore() stored via setSocketError().
std::error_code err = getSocketError();
if (err == std::errc::interrupted) {
throw std::system_error{err};
}
// Retry if the server is not yet listening or if its backlog is exhausted.
if (err == std::errc::connection_refused || err == std::errc::connection_reset) {
C10D_WARNING("The server socket on {} is not yet listening {}.", addr, err);
// Only retry if the pause below would still end before the deadline.
if (Clock::now() < deadline_ - delay_duration_) {
// Wait a little to avoid choking the server.
delay(delay_duration_);
return ConnectResult::Retry;
} else {
return ConnectResult::TimeOut;
}
} else {
recordError("The client socket has failed to connect to {} {}.", addr, err);
return ConnectResult::Error;
}
}
if (cr == ConnectResult::TimeOut) {
return cr;
}
socket_->closeOnExec();
// TODO: Remove once we fully migrate to non-blocking mode.
socket_->disableNonBlocking();
C10D_INFO("The client socket has connected to {} on {}.", addr, *socket_);
if (!socket_->enableNoDelay()) {
C10D_WARNING("The no-delay option cannot be enabled for the client socket on {}.", *socket_);
}
return ConnectResult::Success;
}
// Starts a connect on `socket_` and, if it does not complete immediately, waits
// for its completion until `deadline_`.
//
// Returns Success, TimeOut, or Error. When Error is returned, the cause can be
// retrieved by the caller with getSocketError().
SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(const ::addrinfo& addr) {
int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen);
if (r == 0) {
return ConnectResult::Success;
}
std::error_code err = getSocketError();
if (err == std::errc::already_connected) {
return ConnectResult::Success;
}
// Any error other than "in progress" / "would block" is a real failure.
if (err != std::errc::operation_in_progress && err != std::errc::operation_would_block) {
return ConnectResult::Error;
}
// The connect is pending; wait for the remaining time at most.
Duration remaining = deadline_ - Clock::now();
if (remaining <= Duration::zero()) {
return ConnectResult::TimeOut;
}
::pollfd pfd{};
pfd.fd = socket_->handle();
pfd.events = POLLOUT;
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(remaining);
r = pollFd(&pfd, 1, static_cast<int>(ms.count()));
// A result of zero is treated as the poll having timed out.
if (r == 0) {
return ConnectResult::TimeOut;
}
if (r == -1) {
return ConnectResult::Error;
}
// Query SO_ERROR to learn the outcome of the pending connect.
int err_code = 0;
::socklen_t err_len = sizeof(int);
r = getSocketOption(socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len);
if (r != 0) {
return ConnectResult::Error;
}
if (err_code != 0) {
// Publish the connect error so that the caller sees it via
// getSocketError().
setSocketError(err_code);
return ConnectResult::Error;
} else {
return ConnectResult::Success;
}
}
} // namespace
// Initializes the platform's socket library. Does nothing on platforms other
// than Windows.
//
// Throws `SocketError` if Winsock cannot be started.
void Socket::initialize() {
#ifdef _WIN32
  static std::once_flag init_flag{};

  // All processes that call socket functions on Windows must first initialize
  // the Winsock library.
  auto startWinsock = []() {
    WSADATA data{};

    const int rc = ::WSAStartup(MAKEWORD(2, 2), &data);
    if (rc != 0) {
      throw SocketError{"The initialization of Winsock has failed."};
    }
  };

  std::call_once(init_flag, startWinsock);
#endif
}
// Runs a listen operation for `port` and wraps the resulting socket.
Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) {
  return Socket{SocketListenOp{port, opts}.run()};
}
// Runs a connect operation for (`host`, `port`) and wraps the resulting socket.
Socket Socket::connect(const std::string& host, std::uint16_t port, const SocketOptions& opts) {
  return Socket{SocketConnectOp{host, port, opts}.run()};
}
// The move operations and the destructor are declared in the header and
// defaulted here, in the translation unit where SocketImpl is fully defined.
Socket::Socket(Socket&& other) noexcept = default;
Socket& Socket::operator=(Socket&& other) noexcept = default;
Socket::~Socket() = default;
// Accepts a pending connection and returns it as a new `Socket`.
//
// Throws `SocketError` if this object does not hold a socket.
Socket Socket::accept() const {
  if (!impl_) {
    throw SocketError{"The socket is not initialized."};
  }

  return Socket{impl_->accept()};
}
// Returns the native handle of the socket, or the invalid-socket sentinel if
// this object does not hold a socket.
int Socket::handle() const noexcept {
  if (!impl_) {
    return SocketImpl::invalid_socket;
  }

  return impl_->handle();
}
// Returns the local port of the socket, or 0 if this object does not hold a
// socket.
std::uint16_t Socket::port() const {
  if (!impl_) {
    return 0;
  }

  return impl_->getPort();
}
// Takes ownership of the given implementation object.
Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept
    : impl_(std::move(impl)) {}
} // namespace detail
// Out-of-line definition of the destructor declared in socket.h.
SocketError::~SocketError() = default;
} // namespace c10d
//
// libfmt formatters for `addrinfo` and `Socket`
//
namespace fmt {
// Formats the socket address referenced by an `addrinfo` as "host:port" for
// IPv4 and "[host]:port" for every other family. If the address cannot be
// converted to text, "?UNKNOWN?" is written instead.
template <>
struct formatter<::addrinfo> {
  // No format specifiers are supported.
  constexpr decltype(auto) parse(format_parse_context& ctx) {
    return ctx.begin();
  }

  template <typename FormatContext>
  decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) {
    char host[NI_MAXHOST]; // NOLINT
    char port[NI_MAXSERV]; // NOLINT

    // The port is always rendered numerically.
    const int rc = ::getnameinfo(
        addr.ai_addr, addr.ai_addrlen, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV);
    if (rc != 0) {
      return format_to(ctx.out(), "?UNKNOWN?");
    }

    if (addr.ai_addr->sa_family != AF_INET) {
      return format_to(ctx.out(), "[{}]:{}", host, port);
    }

    return format_to(ctx.out(), "{}:{}", host, port);
  }
};
// Formats a `SocketImpl` as the local address it is bound to. If the local
// address cannot be queried, "?UNKNOWN?" is written instead.
template <>
struct formatter<c10d::detail::SocketImpl> {
  // No format specifiers are supported.
  constexpr decltype(auto) parse(format_parse_context& ctx) {
    return ctx.begin();
  }

  template <typename FormatContext>
  decltype(auto) format(const c10d::detail::SocketImpl& socket, FormatContext& ctx) {
    ::sockaddr_storage storage{};

    ::socklen_t storage_len = sizeof(storage);

    auto* local = reinterpret_cast<::sockaddr*>(&storage);

    const bool resolved = ::getsockname(socket.handle(), local, &storage_len) == 0;
    if (!resolved) {
      return format_to(ctx.out(), "?UNKNOWN?");
    }

    // Wrap the address so that the `addrinfo` formatter can render it.
    ::addrinfo wrapper{};
    wrapper.ai_addr = local;
    wrapper.ai_addrlen = storage_len;

    return format_to(ctx.out(), "{}", wrapper);
  }
};
} // namespace fmt

View File

@ -0,0 +1,100 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <c10/macros/Macros.h>
#include <c10d/exception.h>
namespace c10d {
namespace detail {
// Options that control how a socket is created. Each setter returns a reference
// to the object so that calls can be chained.
class SocketOptions {
 public:
  // Sets whether IPv6 addresses should be tried before IPv4 addresses.
  SocketOptions& prefer_ipv6(bool enabled) noexcept {
    this->prefer_ipv6_ = enabled;

    return *this;
  }

  bool prefer_ipv6() const noexcept {
    return this->prefer_ipv6_;
  }

  // Sets the time limit for establishing a connection.
  SocketOptions& connect_timeout(std::chrono::seconds timeout) noexcept {
    this->connect_timeout_ = timeout;

    return *this;
  }

  std::chrono::seconds connect_timeout() const noexcept {
    return this->connect_timeout_;
  }

 private:
  // IPv6 is preferred by default.
  bool prefer_ipv6_{true};
  // The default connect timeout is 30 seconds.
  std::chrono::seconds connect_timeout_ = std::chrono::seconds{30};
};
class SocketImpl;
// Move-only wrapper around a SocketImpl object, which is only forward declared
// in this header. A default-constructed Socket holds no implementation object.
class Socket {
public:
// This function initializes the underlying socket library and must be called
// before any other socket function.
static void initialize();
// Factory functions for server (listening) and client (connected) sockets.
static Socket listen(std::uint16_t port, const SocketOptions& opts = {});
static Socket connect(const std::string& host, std::uint16_t port, const SocketOptions& opts = {});
Socket() noexcept = default;
// Copying is disabled; only moves are allowed.
Socket(const Socket& other) = delete;
Socket& operator=(const Socket& other) = delete;
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;
~Socket();
// Accepts a pending connection and returns it as a new Socket.
Socket accept() const;
// Returns the native socket handle.
int handle() const noexcept;
// Returns the local port of the socket.
std::uint16_t port() const;
private:
// Takes ownership of `impl`; used by the factory functions and accept().
explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept;
std::unique_ptr<SocketImpl> impl_;
};
} // namespace detail
// Exception type derived from C10dError that is used to report socket
// failures. It inherits every constructor of C10dError; the destructor is only
// declared here and is defined out-of-line.
class TORCH_API SocketError : public C10dError {
public:
// Reuse all constructors of the base class.
using C10dError::C10dError;
// Copy and move operations are explicitly defaulted.
SocketError(const SocketError&) = default;
SocketError& operator=(const SocketError&) = default;
SocketError(SocketError&&) = default;
SocketError& operator=(SocketError&&) = default;
~SocketError() override;
};
} // namespace c10d

View File

@ -111,7 +111,7 @@ class C10dRendezvousBackend(RendezvousBackend):
def _call_store(self, store_op: str, *args, **kwargs) -> Any:
try:
return getattr(self._store, store_op)(*args, **kwargs)
except (ValueError, RuntimeError) as exc:
except (ValueError, RuntimeError, TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
@ -164,7 +164,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
log.info(msg)
break
except (ValueError, RuntimeError) as exc:
except (ValueError, RuntimeError, TimeoutError) as exc:
# If we heuristically inferred the value of is_host as True and our
# first attempt to instantiate the TCP store has failed, try it one
# more time with is_host set to False. As an edge case there can be

View File

@ -16,7 +16,6 @@ from torch.distributed.elastic.utils.logging import get_logger
log = get_logger()
_ADDRESS_IN_USE = "Address already in use"
_CONNECT_TIMEOUT = "connect() timed out."
_SOCKET_TIMEOUT = "Socket Timeout"
_MEMBER_CHECKIN = "_tcp_store/num_members"
@ -75,18 +74,14 @@ def create_c10d_store(
# detects timeouts and port conflicts in their own unittests
# see - caffe2/torch/testing/_internal/common_utils.py
# TODO properly map the exceptions in pybind (c10d/init.cpp)
if _CONNECT_TIMEOUT in str(e) and not is_server:
raise TimeoutError(
f"timed out waiting for tcp store's server: {server_addr}:{port}"
) from e
elif str(e) == _ADDRESS_IN_USE: # this will only happen on the server
if str(e) == _ADDRESS_IN_USE: # this will only happen on the server
if attempt < retries:
log.warning(
f"port: {port} already in use, attempt: [{attempt}/{retries}]"
)
attempt += 1
else:
raise IOError(
raise RuntimeError(
f"on {server_addr}, port: {port} already in use"
) from e
else: