The new backend is currently behind a 'use_libuv' flag in the TCPStore constructor to reduce the impact on existing users while we test it. This is a reland of #105870 with a fix for a bad test.

Differential Revision: [D48742554](https://our.internmc.facebook.com/intern/diff/D48742554)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108066
Approved by: https://github.com/H-Huang, https://github.com/fduwjj
1238 lines · 31 KiB · C++
#include <algorithm>
#include <chrono>
#include <deque>
#include <exception>
#include <memory>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <fmt/format.h>

#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

#ifdef TORCH_USE_LIBUV
#include <uv.h>
#endif

namespace c10d {
namespace detail {

#ifdef TORCH_USE_LIBUV

/*
Exception safety:

It's ok to use exceptions during client processing.
Other callbacks don't provide exception safety, so avoid throwing there.
*/

#define DEFAULT_BACKLOG 2048
#define MAX_KEY_COUNT (128 * 1024)
#define MAX_STRING_LEN (8 * 1024)
#define MAX_PAYLOAD_LEN (8 * 1024 * 1024)

// This controls the preferred size for buffers.
// Too small and we'll need multiple buffers for one request.
// Too big and we risk taxing malloc.
#define ALLOC_BUFFER_SIZE ((size_t)4000)

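// UvHandle ties the lifetime of a c10::intrusive_ptr_target to a libuv
// handle: handleReady() stores `this` in the handle's data pointer and takes
// an extra reference once the handle is registered with the loop, and
// on_close() reclaims that reference when libuv reports the handle closed.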
class UvHandle : public c10::intrusive_ptr_target {
 public:
  virtual ~UvHandle() {}

  c10::intrusive_ptr<UvHandle> iptr() {
    return c10::intrusive_ptr<UvHandle>::reclaim_copy(this);
  }

  void close() {
    if (uv_is_closing(unsafeGetHandle())) {
      return;
    }
    uv_close(unsafeGetHandle(), on_close);
  }

  virtual uv_handle_t* unsafeGetHandle() = 0;

 protected:
  void handleReady() {
    /*
    This method must be called once the handle is ready and registered with
    the loop.

    Do not call this in the ctor; make_intrusive resets the refcount to one
    after construction.
    */
    uv_handle_set_data(unsafeGetHandle(), this);
    at::raw::intrusive_ptr::incref(this);
  }

  virtual void onClose() = 0;

 private:
  static c10::intrusive_ptr<UvHandle> reclaim(uv_handle_t* handle) {
    auto h = (UvHandle*)uv_handle_get_data(handle);
    return c10::intrusive_ptr<UvHandle>::reclaim(h);
  }

  static void on_close(uv_handle_t* uv_handle) {
    auto handle = reclaim(uv_handle);
    handle->onClose();
  }
};

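// UvTcpSocket wraps a uv_tcp_t stream. Read buffers are malloc'd in
// alloc_buffer; on a successful read they are handed to the processBuf
// override (UvClient appends them to a ChunkedStream, which frees them once
// consumed).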
class UvTcpSocket : public UvHandle {
  uv_tcp_t client;

  c10::intrusive_ptr<UvTcpSocket> iptr() {
    return c10::intrusive_ptr<UvTcpSocket>::reclaim_copy(this);
  }

  static c10::intrusive_ptr<UvTcpSocket> borrow(uv_stream_t* handle) {
    auto h = (UvTcpSocket*)uv_handle_get_data((uv_handle_t*)handle);
    return h->iptr();
  }

  static void alloc_buffer(
      uv_handle_t* handle,
      size_t suggested_size,
      uv_buf_t* buf) {
    suggested_size = std::min(suggested_size, (size_t)ALLOC_BUFFER_SIZE);
    buf->base = (char*)malloc(suggested_size);
    buf->len = suggested_size;
  }

  static void read_callback(
      uv_stream_t* client,
      ssize_t nread,
      const uv_buf_t* buf) {
    auto uv_socket = UvTcpSocket::borrow(client);

    if (nread < 0) {
      C10D_DEBUG(
          "Read callback failed. code:{} name:{} desc:{}",
          nread,
          uv_err_name(nread),
          uv_strerror(nread));
      uv_socket->close();
      return;
    }
    if (nread > 0) {
      try {
        uv_socket->processBuf(buf, nread);
      } catch (std::exception& ex) {
        C10D_INFO("Error processing client message: {}", ex.what());
        uv_socket->close();
      }
    }
  }

 public:
  explicit UvTcpSocket(uv_loop_t* loop) {
    uv_tcp_init(loop, &client);
  }

  void startRead() {
    int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback);
    if (res) {
      C10D_INFO(
          "Failed to setup read callback. client:{} code:{} name:{} desc:{}.",
          (void*)this,
          res,
          uv_err_name(res),
          uv_strerror(res));
      close();
    }
  }

  uv_handle_t* unsafeGetHandle() override {
    return (uv_handle_t*)&client;
  }

 protected:
  uv_stream_t* unsafeGetStream() {
    return (uv_stream_t*)&client;
  }

  uv_tcp_t* unsafeGetSocket() {
    return &client;
  }

  virtual void processBuf(const uv_buf_t* buf, size_t nread) {
    TORCH_CHECK(
        false, "Trying to read from a socket subclass that lacks processBuf");
  }

  void onClose() override {
    // TODO use registerClient (and rename it to registerHandle) - this will
    // significantly simplify things.
  }
};

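// UvTcpServer is constructed in one of two ways: makeWithSocket() adopts an
// already-bound file descriptor via uv_tcp_open, while makeWithPort() binds
// a fresh socket to the given port and starts listening.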
class UvTcpServer : public UvTcpSocket {
 public:
  typedef std::function<void(int)> OnConnectCallback;
  explicit UvTcpServer(uv_loop_t* loop)
      : UvTcpSocket(loop), onConnectCb(missingOnConnect) {}

  static c10::intrusive_ptr<UvTcpServer> makeWithSocket(
      uv_loop_t* loop,
      int socket) {
    auto res = c10::make_intrusive<UvTcpServer>(loop);
    res->handleReady();
    try {
      int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket);
      TORCH_CHECK(
          uv_res == 0,
          "Failed to open existing socket. socket:{} code:{} name:{} message:{}",
          socket,
          uv_res,
          uv_err_name(uv_res),
          uv_strerror(uv_res));

      res->cacheSocketPort();
    } catch (std::exception& ex) {
      res->close();
      throw;
    }

    return res;
  }

  void setOnConnectCallback(OnConnectCallback&& callback) {
    onConnectCb = std::move(callback);
  }

  static c10::intrusive_ptr<UvTcpServer> makeWithPort(
      uv_loop_t* loop,
      uint16_t port,
      bool useIpv6) {
    auto res = c10::make_intrusive<UvTcpServer>(loop);
    res->handleReady();
    try {
      struct sockaddr_storage addr;
      int uv_res;
      if (useIpv6) {
        uv_res = uv_ip6_addr("::", port, (struct sockaddr_in6*)&addr);
      } else {
        uv_res = uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*)&addr);
      }
      TORCH_CHECK(
          uv_res == 0,
          "UV Store addr parsing failure. useIpv6:{} code:{} name:{} message:{}",
          useIpv6,
          uv_res,
          uv_err_name(uv_res),
          uv_strerror(uv_res));

      uv_res =
          uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0);
      TORCH_CHECK(
          uv_res == 0,
          "UV Store bind failed. useIpv6:{} code:{} name:{} message:{}",
          useIpv6,
          uv_res,
          uv_err_name(uv_res),
          uv_strerror(uv_res));

      uv_res =
          uv_listen(res->unsafeGetStream(), DEFAULT_BACKLOG, on_new_connection);
      TORCH_CHECK(
          uv_res == 0,
          "UV Store listen failed. useIpv6:{} code:{} name:{} message:{}",
          useIpv6,
          uv_res,
          uv_err_name(uv_res),
          uv_strerror(uv_res));

      res->cacheSocketPort();
    } catch (std::exception& ex) {
      res->close();
      throw;
    }

    return res;
  }

  uint16_t port() const {
    return portNum;
  }

  void accept(c10::intrusive_ptr<UvTcpSocket> socket) {
    int res =
        uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle());
    TORCH_CHECK(
        res == 0,
        "Failed to accept socket. code:{} name:{} desc:{}.",
        res,
        uv_err_name(res),
        uv_strerror(res));
  }

 private:
  OnConnectCallback onConnectCb;
  uint16_t portNum;

  c10::intrusive_ptr<UvTcpServer> iptr() {
    return c10::intrusive_ptr<UvTcpServer>::reclaim_copy(this);
  }

  static c10::intrusive_ptr<UvTcpServer> borrow(uv_stream_t* handle) {
    auto h = (UvTcpServer*)uv_handle_get_data((uv_handle_t*)handle);
    return h->iptr();
  }

  void cacheSocketPort() {
    sockaddr_storage addr_s{};

    int addr_len = sizeof(addr_s);

    if (uv_tcp_getsockname(
            (uv_tcp_t*)unsafeGetStream(),
            reinterpret_cast<sockaddr*>(&addr_s),
            &addr_len) != 0) {
      throw std::runtime_error(
          "The port number of the socket cannot be retrieved.");
    }

    if (addr_s.ss_family == AF_INET) {
      portNum = ntohs(reinterpret_cast<sockaddr_in*>(&addr_s)->sin_port);
    } else {
      portNum = ntohs(reinterpret_cast<sockaddr_in6*>(&addr_s)->sin6_port);
    }
  }

  static void missingOnConnect(int status) {
    TORCH_CHECK(false, "Socket accepted but onConnect callback missing");
  }

  static void on_new_connection(uv_stream_t* server, int status) {
    borrow(server)->onConnectCb(status);
  }
};

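// WriterPayload keeps the outgoing bytes alive while a uv_write is in
// flight: send() increfs via registeredInLoop(), and write_done() reclaims
// that reference once libuv finishes (or fails) the write.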
class WriterPayload : public c10::intrusive_ptr_target {
  static c10::intrusive_ptr<WriterPayload> reclaim(uv_write_t* request) {
    /* This method returns an intrusive_ptr that does not increase the
     * refcount. */
    auto h = (WriterPayload*)uv_req_get_data((uv_req_t*)request);
    return c10::intrusive_ptr<WriterPayload>::reclaim(h);
  }

  void registeredInLoop() {
    /*
    This refcount increment must be matched by a reclaim call.
    Call this method after successfully scheduling this handle with a loop.
    */
    at::raw::intrusive_ptr::incref(this);
  }

  static void write_done(uv_write_t* req, int status) {
    /* Since we're no longer actively used by the event loop, transfer
     * ownership to this frame. */
    auto wp = WriterPayload::reclaim(req);
    auto handle = wp->handle;

    if (status) {
      C10D_INFO(
          "Write to client failed. code:{} name:{} desc:{}.",
          status,
          uv_err_name(status),
          uv_strerror(status));
      handle->close();
    }
  }

  std::vector<uint8_t> data;
  uv_write_t req = {};
  uv_buf_t buf = {};
  c10::intrusive_ptr<UvHandle> handle;

 public:
  WriterPayload(
      std::vector<uint8_t>&& in_data,
      c10::intrusive_ptr<UvHandle> handle)
      : data(std::move(in_data)), handle(handle) {
    uv_req_set_data((uv_req_t*)&req, this);
  }

  ~WriterPayload() {}

  void send() {
    buf = uv_buf_init((char*)data.data(), data.size());
    int res = uv_write(
        &req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done);

    if (res) {
      C10D_INFO(
          "Write setup to client failed. code:{} name:{} desc:{}.",
          res,
          uv_err_name(res),
          uv_strerror(res));
      handle->close();
    } else {
      /* This object was successfully registered with the event loop, so keep
       * it alive until it's unregistered. */
      registeredInLoop();
    }
  }
};

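// StreamWriter serializes responses in the store's wire format: raw
// native-endian byte copies for fixed-size values, and length-prefixed
// (uint64_t count followed by the bytes) strings and vectors. A SET request
// from a client mirrors this layout (see parse_set_command below):
//
//   [1-byte command][uint64 key_len][key bytes][uint64 len][payload bytes]
//
// Note: values are memcpy'd in host byte order; both ends are assumed to
// share endianness.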
class StreamWriter {
  std::vector<uint8_t> data;
  c10::intrusive_ptr<UvHandle> handle;

  // must be stack allocated
  void* operator new(size_t);

 public:
  StreamWriter(c10::intrusive_ptr<UvHandle> handle) : handle(handle) {}

  void write1(uint8_t val) {
    data.push_back(val);
  }

  template <typename T>
  void write_value(T val) {
    uint8_t* val_ptr = (uint8_t*)&val;
    data.insert(data.end(), val_ptr, val_ptr + sizeof(T));
  }

  void write_vector(const std::vector<uint8_t>& val) {
    write_value<uint64_t>(val.size());
    data.insert(data.end(), val.begin(), val.end());
  }

  void write_string(const std::string& val) {
    write_value<uint64_t>(val.size());
    data.insert(data.end(), val.data(), val.data() + val.size());
  }

  void send() {
    auto wd = c10::make_intrusive<WriterPayload>(std::move(data), handle);
    wd->send();
  }
};

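// ChunkedStream stitches the malloc'd read buffers into one logical byte
// stream with transactional reads: reset() rewinds to the last commit()
// point, so a parser that runs out of data can bail out and retry once the
// rest of the message arrives; commit() frees fully-consumed buffers.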
class ChunkedStream {
  std::deque<uv_buf_t> buffers;
  size_t buff_idx;
  size_t buff_offset;
  size_t capacity;
  size_t buff_offset_commit;
  size_t read_offset;

 public:
  ChunkedStream()
      : buff_idx(0),
        buff_offset(0),
        capacity(0),
        buff_offset_commit(0),
        read_offset(0) {}

  size_t buf_count() {
    return buffers.size();
  }

  void append(uv_buf_t buf) {
    if (buf.len == 0) {
      free(buf.base);
    } else {
      capacity += buf.len;
      buffers.push_back(buf);
    }
  }

  bool read_many(char* dest, size_t size) {
    if (available() < size) {
      return false;
    }

    size_t remaining = size;
    char* write_base = dest;
    while (remaining > 0) {
      auto to_read = std::min(buffers[buff_idx].len - buff_offset, remaining);
      ::memcpy(write_base, buffers[buff_idx].base + buff_offset, to_read);
      buff_offset += to_read;
      remaining -= to_read;
      write_base += to_read;
      if (buff_offset >= buffers[buff_idx].len) {
        buff_offset = 0;
        ++buff_idx;
        if (buff_idx >= buffers.size() && remaining > 0) {
          TORCH_CHECK(
              false,
              "Trying to read past end of buffer. buff_idx:{} buf_count:{} remaining:{}",
              buff_idx,
              buffers.size(),
              remaining);
        }
      }
    }
    return true;
  }

  bool read1(uint8_t& byte) {
    while (true) {
      if (buff_idx >= buffers.size())
        return false;
      if (buff_offset >= buffers[buff_idx].len) {
        buff_offset = 0;
        ++buff_idx;
        continue;
      }
      break;
    }

    byte = buffers[buff_idx].base[buff_offset];
    ++buff_offset;
    ++read_offset;
    return true;
  }

  template <typename T>
  bool read_value(T& value) {
    return read_many((char*)&value, sizeof(T));
  }

  bool read_key(std::string& str) {
    uint64_t size = 0;
    if (!read_value(size))
      return false;
    TORCH_CHECK(
        size <= MAX_STRING_LEN,
        "Invalid string size. size:{} max:{}",
        size,
        MAX_STRING_LEN);

    if (available() < size)
      return false;
    str.resize(size);
    return read_many((char*)str.data(), size);
  }

  bool read_payload(std::vector<uint8_t>& data) {
    uint64_t size = 0;
    if (!read_value(size))
      return false;
    auto size_in_bytes = size * sizeof(uint8_t);
    TORCH_CHECK(
        size_in_bytes <= MAX_PAYLOAD_LEN,
        "Invalid payload size. size:{} max:{}",
        size_in_bytes,
        MAX_PAYLOAD_LEN);

    if (available() < size_in_bytes)
      return false;
    data.resize(size);
    return read_many((char*)data.data(), size_in_bytes);
  }

  size_t available() {
    return capacity - read_offset;
  }

  void commit() {
    if (buff_idx >= buffers.size() || buff_offset >= buffers[buff_idx].len) {
      buff_offset = 0;
      if (buff_idx < buffers.size())
        ++buff_idx;
    }

    for (size_t i = 0; i < buff_idx; ++i) {
      free(buffers[0].base);
      capacity -= buffers[0].len;
      buffers.pop_front();
    }
    buff_idx = 0;
    read_offset = buff_offset_commit = buff_offset;
  }

  void reset() {
    buff_idx = 0;
    read_offset = buff_offset = buff_offset_commit;
  }
};

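// LibUVStoreDaemon runs the entire store on a single background thread: the
// uv loop owns all sockets and the key/value map, so no locking is needed.
// stop() is the only cross-thread entry point and signals the loop via a
// uv_async handle.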
class LibUVStoreDaemon : public BackgroundThread {
 public:
  explicit LibUVStoreDaemon(int port);
  ~LibUVStoreDaemon() override;

  uint16_t port() const override;

  void set(const std::string& key, const std::vector<uint8_t>& value);
  const std::vector<uint8_t>& compareAndSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& newValue);
  const std::vector<uint8_t>& get(const std::string& key);
  int64_t add(const std::string& key, int64_t addVal);
  bool checkKeys(const std::vector<std::string>& keys);
  bool waitKeys(
      const std::vector<std::string>& keys,
      c10::intrusive_ptr<UvHandle> client);
  int64_t size();
  int64_t deleteKey(const std::string& key);
  void append(const std::string& key, const std::vector<uint8_t>& value);

  void registerClient(c10::intrusive_ptr<UvHandle> client);
  void unregisterClient(c10::intrusive_ptr<UvHandle> client);
  void clearClientWaitState(c10::intrusive_ptr<UvHandle> client);

  uint16_t get_socket_port(uv_tcp_t* handle);
  void init(const TCPStoreOptions& opts);

 protected:
  void run() override;
  void stop() override;

 private:
  uv_loop_t loop;
  c10::intrusive_ptr<UvTcpServer> tcpServer;

  uv_async_t exit_handle;
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of UvClient waiting on the key
  std::unordered_map<std::string, std::vector<c10::intrusive_ptr<UvHandle>>>
      waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<c10::intrusive_ptr<UvHandle>, size_t> keysAwaited_;
  std::unordered_set<c10::intrusive_ptr<UvHandle>> clients_;
  int port_;

  static LibUVStoreDaemon& from_uv(uv_handle_t* stream) {
    return *(LibUVStoreDaemon*)uv_handle_get_data(stream);
  }

  static void on_new_connection(uv_stream_t* server, int status) {
    from_uv((uv_handle_t*)server).onConnect(status);
  }

  static void on_exit_request(uv_async_t* handle) {
    from_uv((uv_handle_t*)handle).onExitRequest();
  }

  void onConnect(int status);
  void onExitRequest();
  void wakeupWaitingClients(const std::string& key);
  // bool tryListen(bool use_ipv6);

  static void print_active_handles(uv_handle_t* handle, void* arg);
};

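// UvClient parses the request stream one command at a time. Each parse_*
// helper returns false when the full message has not arrived yet; processBuf
// then returns without commit(), and the next read re-runs the parser from
// the last committed offset (stream.reset()).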
class UvClient : public UvTcpSocket {
  ChunkedStream stream;
  LibUVStoreDaemon* store;

 protected:
  void processBuf(const uv_buf_t* buf, size_t nread) override {
    auto tmp = *buf;
    tmp.len = nread;
    stream.append(tmp);

    while (true) {
      stream.reset();
      uint8_t command = -1;
      if (!stream.read1(command))
        break;
      switch ((QueryType)command) {
        case QueryType::SET:
          if (!parse_set_command())
            return;
          break;
        case QueryType::COMPARE_SET:
          if (!parse_compare_set_command())
            return;
          break;
        case QueryType::GET:
          if (!parse_get_command())
            return;
          break;
        case QueryType::ADD:
          if (!parse_add_command())
            return;
          break;
        case QueryType::CHECK:
          if (!parse_check_command())
            return;
          break;
        case QueryType::WAIT:
          if (!parse_wait_command())
            return;
          break;
        case QueryType::GETNUMKEYS:
          if (!parse_getnumkeys_command())
            return;
          break;
        case QueryType::DELETE_KEY:
          if (!parse_delete_key_command())
            return;
          break;
        case QueryType::APPEND:
          if (!parse_append_command())
            return;
          break;
        case QueryType::MULTI_GET:
          if (!parse_multi_get_command())
            return;
          break;
        case QueryType::MULTI_SET:
          if (!parse_multi_set_command())
            return;
          break;
        case QueryType::CANCEL_WAIT:
          if (!parse_cancel_wait_command())
            return;
          break;
        default:
          C10D_DEBUG(
              "Client sent invalid command. client:{} command:{}",
              (void*)this,
              (int)command);
          close();
          return;
      }
      stream.commit();
    }
  }

  bool parse_set_command() {
    std::string key;
    if (!stream.read_key(key))
      return false;

    std::vector<uint8_t> newData;
    if (!stream.read_payload(newData))
      return false;

    store->set(key, newData);
    return true;
  }

  bool parse_compare_set_command() {
    std::string key;
    if (!stream.read_key(key))
      return false;

    std::vector<uint8_t> currentValue;
    if (!stream.read_payload(currentValue))
      return false;

    std::vector<uint8_t> newValue;
    if (!stream.read_payload(newValue))
      return false;

    auto res = store->compareAndSet(key, currentValue, newValue);
    StreamWriter sw(iptr());
    sw.write_vector(res);
    sw.send();

    return true;
  }

  bool parse_get_command() {
    std::string key;
    if (!stream.read_key(key))
      return false;

    auto data = store->get(key);
    StreamWriter sw(iptr());
    sw.write_vector(data);
    sw.send();
    return true;
  }

  bool parse_add_command() {
    std::string key;
    if (!stream.read_key(key))
      return false;

    int64_t addVal = 0;
    if (!stream.read_value(addVal))
      return false;

    addVal = store->add(key, addVal);
    StreamWriter sw(iptr());
    sw.write_value(addVal);
    sw.send();

    return true;
  }

  bool parse_check_command() {
    uint64_t key_count = 0;
    if (!stream.read_value(key_count))
      return false;
    TORCH_CHECK(
        key_count <= MAX_KEY_COUNT,
        "Too many keys being waited. keys:{} max:{}",
        key_count,
        MAX_KEY_COUNT);

    std::vector<std::string> keys(key_count);
    for (uint64_t i = 0; i < key_count; ++i) {
      if (!stream.read_key(keys[i]))
        return false;
    }

    // Now we have received all the keys
    StreamWriter sw(iptr());
    if (store->checkKeys(keys)) {
      sw.write_value(CheckResponseType::READY);
    } else {
      sw.write_value(CheckResponseType::NOT_READY);
    }
    sw.send();
    return true;
  }

  bool parse_wait_command() {
    uint64_t key_count = 0;
    if (!stream.read_value(key_count)) {
      return false;
    }
    TORCH_CHECK(
        key_count <= MAX_KEY_COUNT,
        "Too many keys being waited. keys:{} max:{}",
        key_count,
        MAX_KEY_COUNT);

    std::vector<std::string> keys(key_count);
    for (uint64_t i = 0; i < key_count; ++i) {
      if (!stream.read_key(keys[i]))
        return false;
    }

    if (store->waitKeys(keys, iptr())) {
      StreamWriter sw(iptr());
      sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
      sw.send();
    }

    return true;
  }

  bool parse_getnumkeys_command() {
    StreamWriter sw(iptr());
    sw.write_value<int64_t>(store->size());
    sw.send();

    return true;
  }

  bool parse_delete_key_command() {
    std::string key;
    if (!stream.read_key(key))
      return false;

    auto numDeleted = store->deleteKey(key);
    StreamWriter sw(iptr());
    sw.write_value<int64_t>(numDeleted);
    sw.send();

    return true;
  }

  bool parse_append_command() {
    std::string key;
    if (!stream.read_key(key)) {
      return false;
    }

    std::vector<uint8_t> data;
    if (!stream.read_payload(data)) {
      return false;
    }

    store->append(key, data);
    return true;
  }

  bool parse_multi_get_command() {
    uint64_t key_count = 0;
    if (!stream.read_value(key_count)) {
      return false;
    }
    TORCH_CHECK(
        key_count <= MAX_KEY_COUNT,
        "Too many keys with multi_get. keys:{} max:{}",
        key_count,
        MAX_KEY_COUNT);

    StreamWriter sw(iptr());
    for (const auto _ : c10::irange(key_count)) {
      (void)_; // Suppress unused variable warning
      std::string key;
      if (!stream.read_key(key)) {
        return false;
      }

      auto data = store->get(key);
      sw.write_vector(data);
    }
    sw.send();

    return true;
  }

  bool parse_multi_set_command() {
    uint64_t key_count = 0;
    if (!stream.read_value(key_count)) {
      return false;
    }
    TORCH_CHECK(
        key_count <= MAX_KEY_COUNT,
        "Too many keys with multi_set. keys:{} max:{}",
        key_count,
        MAX_KEY_COUNT);

    for (const auto _ : c10::irange(key_count)) {
      (void)_; // Suppress unused variable warning

      std::string key;
      if (!stream.read_key(key)) {
        return false;
      }

      std::vector<uint8_t> newData;
      if (!stream.read_payload(newData))
        return false;
      store->set(key, newData);
    }

    return true;
  }

  bool parse_cancel_wait_command() {
    store->clearClientWaitState(iptr());

    StreamWriter sw(iptr());
    sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED);
    sw.send();

    return true;
  }

 public:
  explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store)
      : UvTcpSocket(loop), store(store) {}

  static c10::intrusive_ptr<UvClient> make(
      uv_loop_t* loop,
      LibUVStoreDaemon* store) {
    auto res = c10::make_intrusive<UvClient>(loop, store);
    res->handleReady();
    return res;
  }

  c10::intrusive_ptr<UvClient> iptr() {
    return c10::intrusive_ptr<UvClient>::reclaim_copy(this);
  }

 protected:
  void onClose() override {
    store->unregisterClient(iptr());
  }
};

void LibUVStoreDaemon::onConnect(int status) {
  auto client = UvClient::make(&loop, this);
  registerClient(client);
  try {
    tcpServer->accept(client);
    client->startRead();
  } catch (std::exception& e) {
    C10D_INFO("Failed to accept client due to {}", e.what());
    client->close();
  }
}

void LibUVStoreDaemon::onExitRequest() {
  C10D_DEBUG("Store exit requested");
  uv_close((uv_handle_t*)&exit_handle, nullptr);
  uv_stop(&loop);
}

void LibUVStoreDaemon::init(const TCPStoreOptions& opts) {
  if (opts.masterListenFd.has_value()) {
    tcpServer = UvTcpServer::makeWithSocket(&loop, *opts.masterListenFd);
  } else {
    try {
      tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/true);
    } catch (std::exception& ex) {
      C10D_INFO(
          "Failed to bind to ipv6 address, trying ipv4. Error: {}", ex.what());
      tcpServer =
          UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/false);
    }
  }
  tcpServer->setOnConnectCallback(
      [this](auto status) { this->onConnect(status); });

  port_ = tcpServer->port();
  TORCH_CHECK(
      port_ == opts.port || opts.port == 0, // zero means use any port
      "listen fd {} is bound to port {}, expected to be bound to port {}",
      opts.masterListenFd.value_or(-1), // -1 when no fd was passed in
      port_,
      opts.port);
}

LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) {
  TORCH_CHECK(uv_loop_init(&loop) == 0, "Failed to init uv loop");
  TORCH_CHECK(
      uv_async_init(&loop, &exit_handle, LibUVStoreDaemon::on_exit_request) ==
          0,
      "Failed to init uv async event");
  uv_handle_set_data((uv_handle_t*)&exit_handle, this);
}

LibUVStoreDaemon::~LibUVStoreDaemon() {
  if (!is_running()) {
    uv_close((uv_handle_t*)&exit_handle, nullptr);
    uv_run(&loop, UV_RUN_NOWAIT);
    TORCH_CHECK(uv_loop_close(&loop) == 0, "loop cleanup didn't work");
  } else {
    // the daemon thread cleans up libuv
    dispose();
  }
}

uint16_t LibUVStoreDaemon::port() const {
  return port_;
}

void LibUVStoreDaemon::print_active_handles(uv_handle_t* handle, void* arg) {
  C10D_DEBUG(
      "UV live handle type {} active:{} is-closing:{}",
      (int)handle->type,
      uv_is_active(handle),
      uv_is_closing(handle));
}

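// run() drives the uv loop until stop() fires the exit async handle, then
// closes every client and the server, and spins uv_run/uv_loop_close until
// libuv reports that all handles have been released.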
void LibUVStoreDaemon::run() {
  C10D_DEBUG("Uv main loop running");
  int res = uv_run(&loop, UV_RUN_DEFAULT);
  if (res) {
    C10D_DEBUG("UV main loop done: res:{}", res);
  }
  bool debug_enabled =
      c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug);

  if (debug_enabled) {
    C10D_DEBUG("Walking live handles prior to closing clients");
    uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr);
  }

  for (auto it = clients_.begin(); it != clients_.end(); ++it) {
    (*it)->close();
  }
  tcpServer->close();

  if (debug_enabled) {
    C10D_DEBUG("Walking live handles after closing clients");
    uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr);
  }

  while (true) {
    res = uv_loop_close(&loop);
    if (res == 0) {
      break;
    }
    C10D_INFO(
        "uv_loop_close failed. code:{} name:{} desc:{}",
        res,
        uv_err_name(res),
        uv_strerror(res));
    res = uv_run(&loop, UV_RUN_NOWAIT);
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
  }
  C10D_INFO("uv_loop cleanup finished.");
}

void LibUVStoreDaemon::stop() {
  int res = uv_async_send(&exit_handle);
  if (res) {
    C10D_INFO(
        "uv_async_send failed. code:{} name:{} desc:{}",
        res,
        uv_err_name(res),
        uv_strerror(res));
  }
}

void LibUVStoreDaemon::registerClient(c10::intrusive_ptr<UvHandle> client) {
  clients_.insert(client);
}

void LibUVStoreDaemon::unregisterClient(c10::intrusive_ptr<UvHandle> client) {
  clients_.erase(client);
  clearClientWaitState(client);
}

void LibUVStoreDaemon::clearClientWaitState(
    c10::intrusive_ptr<UvHandle> client) {
  if (keysAwaited_.find(client) == keysAwaited_.end()) {
    return;
  }
  keysAwaited_.erase(client);
  for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
    for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
      if (*vecIt == client) {
        vecIt = it->second.erase(vecIt);
      } else {
        ++vecIt;
      }
    }
    if (it->second.empty()) {
      it = waitingSockets_.erase(it);
    } else {
      ++it;
    }
  }
}

void LibUVStoreDaemon::set(
    const std::string& key,
    const std::vector<uint8_t>& value) {
  tcpStore_[key] = value;
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
}

const std::vector<uint8_t>& LibUVStoreDaemon::compareAndSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& newValue) {
  auto pos = tcpStore_.find(key);
  if (pos == tcpStore_.end()) {
    if (expectedValue.empty()) {
      tcpStore_[key] = newValue;
      wakeupWaitingClients(key);
      return newValue;
    } else {
      // TODO: This code path is not ideal as we are "lying" to the caller in
      // case the key does not exist. We should come up with a working
      // solution. It might make more sense to return "".
      wakeupWaitingClients(key);
      return expectedValue;
    }
  } else {
    if (pos->second == expectedValue) {
      pos->second = newValue;
    }
    wakeupWaitingClients(key);
    return pos->second;
  }
}

const std::vector<uint8_t>& LibUVStoreDaemon::get(const std::string& key) {
  static std::vector<uint8_t> missing_key;
  return tcpStore_.count(key) ? tcpStore_.at(key) : missing_key;
}

int64_t LibUVStoreDaemon::add(const std::string& key, int64_t addVal) {
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    auto buf = reinterpret_cast<const char*>(it->second.data());
    auto len = it->second.size();
    addVal += std::stoll(std::string(buf, len));
  }
  auto addValStr = std::to_string(addVal);
  tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());

  // On "add", wake up all clients that have been waiting
  wakeupWaitingClients(key);

  return addVal;
}

bool LibUVStoreDaemon::checkKeys(const std::vector<std::string>& keys) {
  return std::all_of(keys.begin(), keys.end(), [&](const std::string& s) {
    return tcpStore_.count(s) > 0;
  });
}

bool LibUVStoreDaemon::waitKeys(
    const std::vector<std::string>& keys,
    c10::intrusive_ptr<UvHandle> client) {
  if (checkKeys(keys)) {
    return true;
  }
  int numKeysToAwait = 0;
  for (auto& key : keys) {
    // Only count keys that have not already been set
    if (tcpStore_.find(key) == tcpStore_.end()) {
      waitingSockets_[key].push_back(client);
      numKeysToAwait++;
    }
  }
  keysAwaited_[client] = numKeysToAwait;
  return false;
}

int64_t LibUVStoreDaemon::size() {
  return tcpStore_.size();
}

int64_t LibUVStoreDaemon::deleteKey(const std::string& key) {
  return tcpStore_.erase(key);
}

void LibUVStoreDaemon::append(
    const std::string& key,
    const std::vector<uint8_t>& value) {
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    it->second.insert(it->second.end(), value.begin(), value.end());
  } else {
    tcpStore_[key] = value;
  }

  // we should not have clients waiting if we're appending, so it's all fine
  wakeupWaitingClients(key);
}

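// Each "set"-like mutation above calls wakeupWaitingClients: every waiter on
// the key has its pending-key count decremented, and a STOP_WAITING byte is
// sent once a client's count reaches zero.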
void LibUVStoreDaemon::wakeupWaitingClients(const std::string& key) {
  auto socketsToWait = waitingSockets_.find(key);
  if (socketsToWait != waitingSockets_.end()) {
    for (auto client : socketsToWait->second) {
      if (--keysAwaited_[client] == 0) {
        StreamWriter sw(client->iptr());
        sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
        sw.send();
      }
    }
    waitingSockets_.erase(socketsToWait);
  }
}

#endif

std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(
    const TCPStoreOptions& opts) {
#ifdef TORCH_USE_LIBUV
  auto res = std::make_unique<LibUVStoreDaemon>(opts.port);
  res->init(opts);
  return res;
#else
  TORCH_CHECK(false, "LibUV TCPStore implementation missing");
#endif
}

bool is_libuv_tcpstore_backend_available() {
#ifdef TORCH_USE_LIBUV
  return true;
#else
  return false;
#endif
}

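// Usage sketch (illustrative only, not part of this file): this factory is
// what TCPStore reaches when the 'use_libuv' flag mentioned in the commit
// message is set, but a daemon could be created the same way directly:
//
//   c10d::TCPStoreOptions opts;
//   opts.port = 0; // 0 lets the OS pick a free port
//   if (c10d::detail::is_libuv_tcpstore_backend_available()) {
//     auto daemon = c10d::detail::create_libuv_tcpstore_backend(opts);
//     // daemon->port() reports the port actually bound during init().
//   }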
} // namespace detail
} // namespace c10d