diff --git a/BUILD.bazel b/BUILD.bazel
index 71ebc296598..b58fb57199f 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -769,6 +769,7 @@ cc_library(
":caffe2",
":torch_headers",
"@kineto",
+ "@cpp-httplib",
] + if_cuda([
"@cuda//:nvToolsExt",
"@cutlass",
diff --git a/WORKSPACE b/WORKSPACE
index 5b4f2f2e337..4169e0dbce1 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -168,6 +168,12 @@ new_local_repository(
path = "third_party/opentelemetry-cpp",
)
+new_local_repository(
+ name = "cpp-httplib",
+ build_file = "//third_party:cpp-httplib.BUILD",
+ path = "third_party/cpp-httplib",
+)
+
new_local_repository(
name = "tensorpipe",
build_file = "//third_party:tensorpipe.BUILD",
diff --git a/build_variables.bzl b/build_variables.bzl
index 8b5ac4f46d7..20822ba95cf 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/sequence_num.cpp",
"torch/csrc/distributed/c10d/socket.cpp",
"torch/csrc/distributed/c10d/Work.cpp",
+ "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
+ "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
]
# These files are only supported on Linux (and others) but not on Windows.
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index e9b2b20ce6a..fe24571b66a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1179,6 +1179,9 @@ if(USE_KINETO)
${TORCH_ROOT}/third_party/kineto/libkineto/src)
endif()
+target_include_directories(torch_cpu PRIVATE
+ ${TORCH_ROOT}/third_party/cpp-httplib)
+
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 8c7751f4c07..9693ac6e9fe 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1681,3 +1681,7 @@ endif()
# Include google/FlatBuffers
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)
+
+# Include cpp-httplib
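+# (cpp-httplib is header-only, so an INTERFACE imported target carrying just
+# the include path is sufficient.)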
+add_library(httplib INTERFACE IMPORTED)
+target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib)
diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst
index 24d33d1982d..0aabb560c9c 100644
--- a/docs/source/distributed.elastic.rst
+++ b/docs/source/distributed.elastic.rst
@@ -29,6 +29,7 @@ Documentation
elastic/metrics
elastic/events
elastic/subprocess_handler
+ elastic/control_plane
.. toctree::
:maxdepth: 1
diff --git a/docs/source/elastic/control_plane.rst b/docs/source/elastic/control_plane.rst
new file mode 100644
index 00000000000..c37454cf1b0
--- /dev/null
+++ b/docs/source/elastic/control_plane.rst
@@ -0,0 +1,10 @@
+Control Plane
+=============
+
+.. automodule:: torch.distributed.elastic.control_plane
+.. currentmodule:: torch.distributed.elastic.control_plane
+
+This module contains optional helpers that add extra debug and control
+handlers to your application.
+
+.. autofunction:: torch.distributed.elastic.control_plane.worker_main
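+
+The worker server exposes a small set of HTTP endpoints over the unix socket:
+``GET /`` (an index page), ``GET /handler/`` (the registered handler names as
+a JSON list), and ``POST /handler/<name>`` (invokes the named handler).
+
+A minimal usage sketch (the socket path here is just an example)::
+
+    import os
+
+    from torch.distributed.elastic.control_plane import worker_main
+
+    os.environ["TORCH_WORKER_SERVER_SOCKET"] = "/tmp/worker.sock"
+
+    with worker_main():
+        ...  # training loop; handlers are served on the socket while inside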
diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py
new file mode 100644
index 00000000000..c9ae512f271
--- /dev/null
+++ b/test/distributed/elastic/test_control_plane.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# Owner(s): ["oncall: distributed"]
+
+import json
+import os
+import pickle
+import socket
+import tempfile
+from contextlib import contextmanager
+from typing import Generator
+
+from urllib3.connection import HTTPConnection
+from urllib3.connectionpool import HTTPConnectionPool
+
+from torch.distributed.elastic.control_plane import (
+ TORCH_WORKER_SERVER_SOCKET,
+ worker_main,
+)
+from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+
+
+class UnixHTTPConnection(HTTPConnection):
+ def __init__(self, socket_path: str) -> None:
+ super().__init__("localhost")
+
+ self.socket_path = socket_path
+
+ def connect(self) -> None:
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.sock.connect(self.socket_path)
+
+
+class UnixHTTPConnectionPool(HTTPConnectionPool):
+ def __init__(self, socket_path: str) -> None:
+ super().__init__("localhost")
+
+ self.socket_path = socket_path
+
+ def _new_conn(self):
+ return UnixHTTPConnection(self.socket_path)
+
+
+@contextmanager
+def local_worker_server() -> Generator[UnixHTTPConnectionPool, None, None]:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ socket_path = os.path.join(tmpdir, "socket.sock")
+ os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path
+
+ with worker_main():
+ pool = UnixHTTPConnectionPool(socket_path)
+ yield pool
+
+
+class WorkerServerTest(TestCase):
+ def test_worker_server(self) -> None:
+ with local_worker_server() as pool:
+ resp = pool.request("GET", "/")
+ self.assertEqual(resp.status, 200)
+ self.assertEqual(
+ resp.data,
+                b"""<h1>torch.distributed.WorkerServer</h1>
+<a href="/handler/">Handler names</a>
+""",
+ )
+
+ resp = pool.request("POST", "/handler/ping")
+ self.assertEqual(resp.status, 200)
+ self.assertEqual(resp.data, b"pong")
+
+ resp = pool.request("GET", "/handler/")
+ self.assertEqual(resp.status, 200)
+ self.assertIn("ping", json.loads(resp.data))
+
+            resp = pool.request("POST", "/handler/nonexistent")
+            self.assertEqual(resp.status, 404)
+            self.assertIn(b"Handler nonexistent not found:", resp.data)
+
+ @requires_cuda
+ def test_dump_nccl_trace_pickle(self) -> None:
+ with local_worker_server() as pool:
+ resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
+ self.assertEqual(resp.status, 200)
+ out = pickle.loads(resp.data)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD
new file mode 100644
index 00000000000..3cd0c3dbe94
--- /dev/null
+++ b/third_party/cpp-httplib.BUILD
@@ -0,0 +1,10 @@
+load("@rules_cc//cc:defs.bzl", "cc_library")
+
+cc_library(
+ name = "cpp-httplib",
+ hdrs = ["httplib.h"],
+ includes = [
+ "/",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index fa62688c7e8..10a44af747b 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
${TORCH_ROOT}/third_party/onnx
${TORCH_ROOT}/third_party/flatbuffers/include
${TORCH_ROOT}/third_party/kineto/libkineto/include
+ ${TORCH_ROOT}/third_party/cpp-httplib
${TORCH_SRC_DIR}/csrc
${TORCH_SRC_DIR}/csrc/api/include
@@ -80,6 +81,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES
Python::Module
pybind::pybind11
opentelemetry::api
+ httplib
shm
fmt::fmt-header-only
ATEN_CPU_FILES_GEN_LIB)
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 1a3e4ea6334..dab215d396c 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -94,6 +94,10 @@ class Logger:
def _set_uneven_input_join(self) -> None: ...
def _set_static_graph(self) -> None: ...
+class _WorkerServer:
+ def __init__(self, socket_path: str) -> None: ...
+ def shutdown(self) -> None: ...
+
def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 1c0bdc43be3..2e55bfdb6f3 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -28,6 +28,7 @@
#include
#include
#include
+#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include
#include
@@ -369,6 +370,13 @@ std::string dump_nccl_trace() {
}
#endif
+// TODO(c-p-i-o): add a JSON endpoint.
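+// Registered statically below; once a control_plane::WorkerServer is running,
+// this handler is reachable as POST /handler/dump_nccl_trace_pickle on its
+// unix socket.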
+control_plane::RegisterHandler dumpHandler{
+ "dump_nccl_trace_pickle",
+ [](const control_plane::Request&, control_plane::Response& res) {
+ res.setContent(dump_nccl_trace(), "application/octet-stream");
+ }};
+
std::optional<std::function<void(std::function<void(std::string)>)>>&
get_cpp_trace_dumper() {
static std::optional<
diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp
new file mode 100644
index 00000000000..e29f1e3a2ac
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp
@@ -0,0 +1,75 @@
+#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
+
+#include <fmt/format.h>
+#include <mutex>
+#include <shared_mutex>
+#include <stdexcept>
+#include <unordered_map>
+
+namespace c10d {
+namespace control_plane {
+
+namespace {
+
+class HandlerRegistry {
+ public:
+ void registerHandler(const std::string& name, HandlerFunc f) {
+ std::unique_lock lock(handlersMutex_);
+
+ if (handlers_.find(name) != handlers_.end()) {
+ throw std::runtime_error(
+ fmt::format("Handler {} already registered", name));
+ }
+
+ handlers_[name] = f;
+ }
+
+ HandlerFunc getHandler(const std::string& name) {
+ std::shared_lock lock(handlersMutex_);
+
+ auto it = handlers_.find(name);
+ if (it == handlers_.end()) {
+ throw std::runtime_error(fmt::format("Failed to find handler {}", name));
+ }
+    return it->second;
+ }
+
+  std::vector<std::string> getHandlerNames() {
+    std::shared_lock lock(handlersMutex_);
+
+    std::vector<std::string> names;
+ for (const auto& [name, _] : handlers_) {
+ names.push_back(name);
+ }
+ return names;
+ }
+
+ private:
+ std::shared_mutex handlersMutex_{};
+  std::unordered_map<std::string, HandlerFunc> handlers_{};
+};
+
+HandlerRegistry& getHandlerRegistry() {
+ static HandlerRegistry registry;
+ return registry;
+}
+
+RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
+ res.setContent("pong", "text/plain");
+ }};
+
+} // namespace
+
+void registerHandler(const std::string& name, HandlerFunc f) {
+ return getHandlerRegistry().registerHandler(name, f);
+}
+
+HandlerFunc getHandler(const std::string& name) {
+ return getHandlerRegistry().getHandler(name);
+}
+
+std::vector<std::string> getHandlerNames() {
+ return getHandlerRegistry().getHandlerNames();
+}
+
+} // namespace control_plane
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp
new file mode 100644
index 00000000000..0c106305493
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include <c10/macros/Export.h>
+
+namespace c10d {
+namespace control_plane {
+
+// Request represents a request to the handler. This conceptually maps to an
+// HTTP request but could be called via other transports.
+class TORCH_API Request {
+ public:
+ virtual ~Request() = default;
+
+ virtual const std::string& body() = 0;
+};
+
+// Response represents a response to the handler. This conceptually maps to an
+// HTTP response but could be called via other transports.
+class TORCH_API Response {
+ public:
+ virtual ~Response() = default;
+
+ // Set the response body to the provided string.
+ // TODO: add support for chunked responses
+ virtual void setContent(
+ std::string&& content,
+ const std::string& content_type) = 0;
+
+ // Set the response status code.
+ // These should match standard HTTP status codes.
+ virtual void setStatus(int status) = 0;
+};
+
+using HandlerFunc = std::function<void(const Request&, Response&)>;
+
+// Registers a handler. The name needs to be unique and can be called by using
+// getHandler directly or via WorkerServer for remote requests.
+// These handlers are called from a background C++ thread concurrently with the
+// main thread. These handlers need to be thread safe and not cause issues
+// during Python training.
+TORCH_API void registerHandler(const std::string& name, HandlerFunc f);
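+//
+// A minimal sketch (the handler name here is made up for illustration):
+//
+//   registerHandler("status", [](const Request&, Response& res) {
+//     res.setContent("ok", "text/plain");
+//   });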
+
+// Fetches a handler by name.
+TORCH_API HandlerFunc getHandler(const std::string& name);
+
+TORCH_API std::vector<std::string> getHandlerNames();
+
+// Registers a handler statically.
+// See registerHandler for more details.
+class TORCH_API RegisterHandler {
+ public:
+ RegisterHandler(const std::string& name, HandlerFunc f) {
+ registerHandler(name, f);
+ }
+
+ // disable move, copy
+ RegisterHandler(const RegisterHandler&) = delete;
+ RegisterHandler(RegisterHandler&&) = delete;
+ RegisterHandler& operator=(const RegisterHandler&) = delete;
+ RegisterHandler& operator=(RegisterHandler&&) = delete;
+};
+
+} // namespace control_plane
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
new file mode 100644
index 00000000000..14d287e9607
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
@@ -0,0 +1,178 @@
+#include <filesystem>
+#include <iomanip>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+
+#include <fmt/format.h>
+#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
+#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
+#include <torch/csrc/distributed/c10d/logging.h>
+
+namespace c10d {
+namespace control_plane {
+
+namespace {
+class RequestImpl : public Request {
+ public:
+ RequestImpl(const httplib::Request& req) : req_(req) {}
+
+ const std::string& body() override {
+ return req_.body;
+ }
+
+ private:
+ const httplib::Request& req_;
+};
+
+class ResponseImpl : public Response {
+ public:
+ ResponseImpl(httplib::Response& res) : res_(res) {}
+
+ void setStatus(int status) override {
+ res_.status = status;
+ }
+
+ void setContent(std::string&& content, const std::string& content_type)
+ override {
+ res_.set_content(std::move(content), content_type);
+ }
+
+ private:
+ httplib::Response& res_;
+};
+
+std::string jsonStrEscape(const std::string& str) {
+ std::ostringstream ostream;
+ for (char ch : str) {
+ if (ch == '"') {
+ ostream << "\\\"";
+ } else if (ch == '\\') {
+ ostream << "\\\\";
+ } else if (ch == '\b') {
+ ostream << "\\b";
+ } else if (ch == '\f') {
+ ostream << "\\f";
+ } else if (ch == '\n') {
+ ostream << "\\n";
+ } else if (ch == '\r') {
+ ostream << "\\r";
+ } else if (ch == '\t') {
+ ostream << "\\t";
+ } else if ('\x00' <= ch && ch <= '\x1f') {
+ ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
+              << static_cast<int>(ch);
+ } else {
+ ostream << ch;
+ }
+ }
+ return ostream.str();
+}
+} // namespace
+
+WorkerServer::WorkerServer(const std::string& socketFile) {
+ // using unix sockets
+ server_.set_address_family(AF_UNIX);
+
+ // adjust keep alives as it stops the server from shutting down quickly
+ server_.set_keep_alive_timeout(1); // second, default is 5
+ server_.set_keep_alive_max_count(
+ 30); // wait max 30 seconds before closing socket
+
+ server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
+ res.set_content(
+        R"BODY(<h1>torch.distributed.WorkerServer</h1>
+<a href="/handler/">Handler names</a>
+)BODY",
+ "text/html");
+ });
+ server_.Get(
+ "/handler/", [](const httplib::Request& req, httplib::Response& res) {
+ std::ostringstream body;
+ body << "[";
+ bool first = true;
+ for (const auto& name : getHandlerNames()) {
+ if (!first) {
+ body << ",";
+ }
+ first = false;
+
+ body << "\"" << jsonStrEscape(name) << "\"";
+ }
+ body << "]";
+
+ res.set_content(body.str(), "application/json");
+ });
+ server_.Post(
+ "/handler/:handler",
+ [](const httplib::Request& req, httplib::Response& res) {
+ auto handler_name = req.path_params.at("handler");
+ HandlerFunc handler;
+ try {
+ handler = getHandler(handler_name);
+ } catch (const std::exception& e) {
+ res.status = 404;
+ res.set_content(
+ fmt::format("Handler {} not found: {}", handler_name, e.what()),
+ "text/plain");
+ return;
+ }
+ RequestImpl torchReq{req};
+ ResponseImpl torchRes{res};
+
+ try {
+ handler(torchReq, torchRes);
+ } catch (const std::exception& e) {
+ res.status = 500;
+ res.set_content(
+ fmt::format("Handler {} failed: {}", handler_name, e.what()),
+ "text/plain");
+ return;
+ } catch (...) {
+ res.status = 500;
+ res.set_content(
+ fmt::format(
+ "Handler {} failed with unknown exception", handler_name),
+ "text/plain");
+ return;
+ }
+ });
+
+ if (std::filesystem::exists(socketFile)) {
+ throw std::runtime_error(fmt::format("{} already exists", socketFile));
+ }
+
+ C10D_WARNING("Server listening to {}", socketFile);
+ if (!server_.bind_to_port(socketFile, 80)) {
+ throw std::runtime_error(fmt::format("Error binding to {}", socketFile));
+ }
+
+ serverThread_ = std::thread([this]() {
+ try {
+ if (!server_.listen_after_bind()) {
+ throw std::runtime_error("failed to listen");
+ }
+ } catch (std::exception& e) {
+ C10D_ERROR("Error while running server: {}", e.what());
+ throw;
+ }
+ C10D_WARNING("Server exited");
+ });
+}
+
+void WorkerServer::shutdown() {
+ C10D_WARNING("Server shutting down");
+ server_.stop();
+ serverThread_.join();
+}
+
+WorkerServer::~WorkerServer() {
+ if (serverThread_.joinable()) {
+ C10D_WARNING("WorkerServer destructor called without shutdown");
+ shutdown();
+ }
+}
+
+} // namespace control_plane
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
new file mode 100644
index 00000000000..7d64038f0b0
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <thread>
+
+#include <httplib.h>
+
+#include <c10/macros/Export.h>
+#include <c10/util/intrusive_ptr.h>
+
+namespace c10d {
+namespace control_plane {
+
+class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
+ public:
+ WorkerServer(const std::string& socketFile);
+ ~WorkerServer();
+
+ void shutdown();
+
+ private:
+ httplib::Server server_;
+ std::thread serverThread_;
+};
+
+} // namespace control_plane
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 6f6dae32606..c4b9a9823c8 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -8,6 +8,7 @@
#include
#include
#include
+#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
#include
#ifndef _WIN32
#include
@@ -3164,6 +3165,17 @@ such as `dist.all_reduce(tensor, async_op=True)`.
return py::bytes(::c10d::dump_nccl_trace());
});
#endif
+
+ intrusive_ptr_class_<::c10d::control_plane::WorkerServer>(
+ module, "_WorkerServer", R"(
+)")
+ .def(
+ py::init([](const std::string& socketPath) {
+ return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
+ socketPath);
+ }),
+ py::arg("socket_path"))
+ .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
Py_RETURN_TRUE;
}
diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py
new file mode 100644
index 00000000000..16038363786
--- /dev/null
+++ b/torch/distributed/elastic/control_plane.py
@@ -0,0 +1,51 @@
+import os
+from contextlib import contextmanager, ExitStack
+from typing import Generator
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+__all__ = [
+ "worker_main",
+]
+
+TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
+
+
+@contextmanager
+def _worker_server(socket_path: str) -> Generator[None, None, None]:
+ from torch._C._distributed_c10d import _WorkerServer
+
+ server = _WorkerServer(socket_path)
+ try:
+ yield
+ finally:
+ server.shutdown()
+
+
+@contextmanager
+@record
+def worker_main() -> Generator[None, None, None]:
+    """
+    This is a context manager that wraps your main entry function. It combines
+    the existing ``errors.record`` logic with a new ``_WorkerServer`` that
+    exposes handlers via a unix socket specified by
+    ``TORCH_WORKER_SERVER_SOCKET``.
+
+    Example::
+
+        @worker_main()
+        def main():
+            pass
+
+        if __name__ == "__main__":
+            main()
+    """
+ with ExitStack() as stack:
+ socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
+ if socket_path is not None:
+ stack.enter_context(_worker_server(socket_path))
+
+ yield