diff --git a/BUILD.bazel b/BUILD.bazel
index 71ebc296598..b58fb57199f 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -769,6 +769,7 @@ cc_library(
         ":caffe2",
         ":torch_headers",
         "@kineto",
+        "@cpp-httplib",
     ] + if_cuda([
         "@cuda//:nvToolsExt",
         "@cutlass",
diff --git a/WORKSPACE b/WORKSPACE
index 5b4f2f2e337..4169e0dbce1 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -168,6 +168,12 @@ new_local_repository(
     path = "third_party/opentelemetry-cpp",
 )
 
+new_local_repository(
+    name = "cpp-httplib",
+    build_file = "//third_party:cpp-httplib.BUILD",
+    path = "third_party/cpp-httplib",
+)
+
 new_local_repository(
     name = "tensorpipe",
     build_file = "//third_party:tensorpipe.BUILD",
diff --git a/build_variables.bzl b/build_variables.bzl
index 8b5ac4f46d7..20822ba95cf 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/sequence_num.cpp",
     "torch/csrc/distributed/c10d/socket.cpp",
     "torch/csrc/distributed/c10d/Work.cpp",
+    "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
+    "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
 ]
 
 # These files are only supported on Linux (and others) but not on Windows.
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index e9b2b20ce6a..fe24571b66a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1179,6 +1179,9 @@ if(USE_KINETO)
     ${TORCH_ROOT}/third_party/kineto/libkineto/src)
 endif()
 
+target_include_directories(torch_cpu PRIVATE
+  ${TORCH_ROOT}/third_party/cpp-httplib)
+
 install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
         DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
         FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 8c7751f4c07..9693ac6e9fe 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1681,3 +1681,7 @@ endif()
 
 # Include google/FlatBuffers
 include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)
+
+# Include cpp-httplib
+add_library(httplib INTERFACE IMPORTED)
+target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib)
diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst
index 24d33d1982d..0aabb560c9c 100644
--- a/docs/source/distributed.elastic.rst
+++ b/docs/source/distributed.elastic.rst
@@ -29,6 +29,7 @@ Documentation
    elastic/metrics
    elastic/events
    elastic/subprocess_handler
+   elastic/control_plane
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/elastic/control_plane.rst b/docs/source/elastic/control_plane.rst
new file mode 100644
index 00000000000..c37454cf1b0
--- /dev/null
+++ b/docs/source/elastic/control_plane.rst
@@ -0,0 +1,10 @@
+Control Plane
+=============
+
+.. automodule:: torch.distributed.elastic.control_plane
+.. currentmodule:: torch.distributed.elastic.control_plane
+
+This module contains optional helpers that add extra debug and control
+handlers to your application.
+
+.. autofunction:: torch.distributed.elastic.control_plane.worker_main
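For context, this is how a training script would opt in to the server documented above — a minimal sketch, assuming the socket path is chosen by the caller (the path and function names in the body are illustrative, not part of this diff)::

    # sketch: enable the worker server for a training script
    import os

    # must be set before worker_main() is entered; path is an assumption
    os.environ["TORCH_WORKER_SERVER_SOCKET"] = "/tmp/worker.sock"

    from torch.distributed.elastic.control_plane import worker_main

    @worker_main()  # starts the _WorkerServer for the duration of main()
    def main() -> None:
        ...  # training loop goes here

    if __name__ == "__main__":
        main()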
diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py
new file mode 100644
index 00000000000..c9ae512f271
--- /dev/null
+++ b/test/distributed/elastic/test_control_plane.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# Owner(s): ["oncall: distributed"]
+
+import json
+import os
+import pickle
+import socket
+import tempfile
+from contextlib import contextmanager
+from typing import Generator
+
+from urllib3.connection import HTTPConnection
+from urllib3.connectionpool import HTTPConnectionPool
+
+from torch.distributed.elastic.control_plane import (
+    TORCH_WORKER_SERVER_SOCKET,
+    worker_main,
+)
+from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+
+
+class UnixHTTPConnection(HTTPConnection):
+    def __init__(self, socket_path: str) -> None:
+        super().__init__("localhost")
+
+        self.socket_path = socket_path
+
+    def connect(self) -> None:
+        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        self.sock.connect(self.socket_path)
+
+
+class UnixHTTPConnectionPool(HTTPConnectionPool):
+    def __init__(self, socket_path: str) -> None:
+        super().__init__("localhost")
+
+        self.socket_path = socket_path
+
+    def _new_conn(self):
+        return UnixHTTPConnection(self.socket_path)
+
+
+@contextmanager
+def local_worker_server() -> Generator[UnixHTTPConnectionPool, None, None]:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        socket_path = os.path.join(tmpdir, "socket.sock")
+        os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path
+
+        with worker_main():
+            pool = UnixHTTPConnectionPool(socket_path)
+            yield pool
+
+
+class WorkerServerTest(TestCase):
+    def test_worker_server(self) -> None:
+        with local_worker_server() as pool:
+            resp = pool.request("GET", "/")
+            self.assertEqual(resp.status, 200)
+            self.assertEqual(
+                resp.data,
+                b"""<h1>torch.distributed.WorkerServer</h1>
+<a href="/handler/">Handler names</a>
+""",
+            )
+
+            resp = pool.request("POST", "/handler/ping")
+            self.assertEqual(resp.status, 200)
+            self.assertEqual(resp.data, b"pong")
+
+            resp = pool.request("GET", "/handler/")
+            self.assertEqual(resp.status, 200)
+            self.assertIn("ping", json.loads(resp.data))
+
+            resp = pool.request("POST", "/handler/nonexistent")
+            self.assertEqual(resp.status, 404)
+            self.assertIn(b"Handler nonexistent not found:", resp.data)
+
+    @requires_cuda
+    def test_dump_nccl_trace_pickle(self) -> None:
+        with local_worker_server() as pool:
+            resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
+            self.assertEqual(resp.status, 200)
+            # Smoke test: the payload must unpickle cleanly.
+            pickle.loads(resp.data)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD
new file mode 100644
index 00000000000..3cd0c3dbe94
--- /dev/null
+++ b/third_party/cpp-httplib.BUILD
@@ -0,0 +1,10 @@
+load("@rules_cc//cc:defs.bzl", "cc_library")
+
+cc_library(
+    name = "cpp-httplib",
+    hdrs = ["httplib.h"],
+    includes = [
+        "/",
+    ],
+    visibility = ["//visibility:public"],
+)
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index fa62688c7e8..10a44af747b 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
     ${TORCH_ROOT}/third_party/onnx
     ${TORCH_ROOT}/third_party/flatbuffers/include
     ${TORCH_ROOT}/third_party/kineto/libkineto/include
+    ${TORCH_ROOT}/third_party/cpp-httplib
 
     ${TORCH_SRC_DIR}/csrc
     ${TORCH_SRC_DIR}/csrc/api/include
@@ -80,6 +81,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES
     Python::Module
     pybind::pybind11
     opentelemetry::api
+    httplib
     shm
     fmt::fmt-header-only
     ATEN_CPU_FILES_GEN_LIB)
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 1a3e4ea6334..dab215d396c 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -94,6 +94,10 @@ class Logger:
     def _set_uneven_input_join(self) -> None: ...
     def _set_static_graph(self) -> None: ...
 
+class _WorkerServer:
+    def __init__(self, socket_path: str) -> None: ...
+    def shutdown(self) -> None: ...
+
 def get_debug_level(): ...
 def set_debug_level(): ...
 def set_debug_level_from_env(): ...
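The `_WorkerServer` binding stubbed above can also be driven directly, without `worker_main`. A minimal sketch of manual lifetime management (the socket path is illustrative)::

    # sketch: manage the server lifetime by hand
    from torch._C._distributed_c10d import _WorkerServer

    # binds the Unix socket and starts the background serving thread;
    # raises if the socket file already exists
    server = _WorkerServer("/tmp/worker.sock")
    try:
        ...  # application code; handlers are reachable over the socket
    finally:
        server.shutdown()  # stops the server and joins its thread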
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 1c0bdc43be3..2e55bfdb6f3 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
 #include
 #include
@@ -369,6 +370,13 @@ std::string dump_nccl_trace() {
 }
 #endif
 
+// TODO(c-p-i-o): add a JSON endpoint.
+control_plane::RegisterHandler dumpHandler{
+    "dump_nccl_trace_pickle",
+    [](const control_plane::Request&, control_plane::Response& res) {
+      res.setContent(dump_nccl_trace(), "application/octet-stream");
+    }};
+
 std::optional<std::function<void(std::function<void(const std::string&)>)>>&
 get_cpp_trace_dumper() {
   static std::optional<
diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp
new file mode 100644
index 00000000000..e29f1e3a2ac
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp
@@ -0,0 +1,75 @@
+#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
+
+#include <fmt/format.h>
+#include <shared_mutex>
+#include <stdexcept>
+#include <unordered_map>
+
+namespace c10d {
+namespace control_plane {
+
+namespace {
+
+class HandlerRegistry {
+ public:
+  void registerHandler(const std::string& name, HandlerFunc f) {
+    std::unique_lock lock(handlersMutex_);
+
+    if (handlers_.find(name) != handlers_.end()) {
+      throw std::runtime_error(
+          fmt::format("Handler {} already registered", name));
+    }
+
+    handlers_[name] = f;
+  }
+
+  HandlerFunc getHandler(const std::string& name) {
+    std::shared_lock lock(handlersMutex_);
+
+    auto it = handlers_.find(name);
+    if (it == handlers_.end()) {
+      throw std::runtime_error(fmt::format("Failed to find handler {}", name));
+    }
+    return it->second;
+  }
+
+  std::vector<std::string> getHandlerNames() {
+    std::shared_lock lock(handlersMutex_);
+
+    std::vector<std::string> names;
+    for (const auto& [name, _] : handlers_) {
+      names.push_back(name);
+    }
+    return names;
+  }
+
+ private:
+  std::shared_mutex handlersMutex_{};
+  std::unordered_map<std::string, HandlerFunc> handlers_{};
+};
+
+HandlerRegistry& getHandlerRegistry() {
+  static HandlerRegistry registry;
+  return registry;
+}
+
+RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
+                              res.setContent("pong", "text/plain");
+                            }};
+
+} // namespace
+
+void registerHandler(const std::string& name, HandlerFunc f) {
+  return getHandlerRegistry().registerHandler(name, f);
+}
+
+HandlerFunc getHandler(const std::string& name) {
+  return getHandlerRegistry().getHandler(name);
+}
+
+std::vector<std::string> getHandlerNames() {
+  return getHandlerRegistry().getHandlerNames();
+}
+
+} // namespace control_plane
+} // namespace c10d
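To exercise the built-in `ping` handler and the registry listing, a client just needs HTTP over the Unix socket. A sketch reusing the `UnixHTTPConnectionPool` helper from the test above (the socket path is an assumption, and the helper must be importable in your environment)::

    # sketch: assumes a server is already listening on /tmp/worker.sock
    import json

    pool = UnixHTTPConnectionPool("/tmp/worker.sock")  # helper from test_control_plane.py

    # POST /handler/<name> invokes a registered handler
    resp = pool.request("POST", "/handler/ping")
    assert resp.data == b"pong"

    # GET /handler/ returns a JSON array of registered handler names
    names = json.loads(pool.request("GET", "/handler/").data)
    assert "ping" in names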
diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp
new file mode 100644
index 00000000000..0c106305493
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include <c10/macros/Export.h>
+
+namespace c10d {
+namespace control_plane {
+
+// Request represents a request to the handler. This conceptually maps to an
+// HTTP request but could be called via other transports.
+class TORCH_API Request {
+ public:
+  virtual ~Request() = default;
+
+  virtual const std::string& body() = 0;
+};
+
+// Response represents a response to the handler. This conceptually maps to an
+// HTTP response but could be called via other transports.
+class TORCH_API Response {
+ public:
+  virtual ~Response() = default;
+
+  // Set the response body to the provided string.
+  // TODO: add support for chunked responses
+  virtual void setContent(
+      std::string&& content,
+      const std::string& content_type) = 0;
+
+  // Set the response status code.
+  // These should match standard HTTP status codes.
+  virtual void setStatus(int status) = 0;
+};
+
+using HandlerFunc = std::function<void(const Request&, Response&)>;
+
+// Registers a handler. The name must be unique; the handler can then be
+// fetched directly via getHandler or invoked remotely via WorkerServer.
+// Handlers run on a background C++ thread, concurrently with the main thread,
+// so they must be thread safe and must not interfere with Python training
+// running on the main thread.
+TORCH_API void registerHandler(const std::string& name, HandlerFunc f);
+
+// Fetches a handler by name.
+TORCH_API HandlerFunc getHandler(const std::string& name);
+
+TORCH_API std::vector<std::string> getHandlerNames();
+
+// Registers a handler statically.
+// See registerHandler for more details.
+class TORCH_API RegisterHandler {
+ public:
+  RegisterHandler(const std::string& name, HandlerFunc f) {
+    registerHandler(name, f);
+  }
+
+  // disable move, copy
+  RegisterHandler(const RegisterHandler&) = delete;
+  RegisterHandler(RegisterHandler&&) = delete;
+  RegisterHandler& operator=(const RegisterHandler&) = delete;
+  RegisterHandler& operator=(RegisterHandler&&) = delete;
+};
+
+} // namespace control_plane
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
new file mode 100644
index 00000000000..14d287e9607
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
@@ -0,0 +1,178 @@
+#include <filesystem>
+#include <iomanip>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+
+#include <fmt/format.h>
+#include <httplib.h>
+#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
+#include <torch/csrc/distributed/c10d/logging.h>
+
+namespace c10d {
+namespace control_plane {
+
+namespace {
+class RequestImpl : public Request {
+ public:
+  RequestImpl(const httplib::Request& req) : req_(req) {}
+
+  const std::string& body() override {
+    return req_.body;
+  }
+
+ private:
+  const httplib::Request& req_;
+};
+
+class ResponseImpl : public Response {
+ public:
+  ResponseImpl(httplib::Response& res) : res_(res) {}
+
+  void setStatus(int status) override {
+    res_.status = status;
+  }
+
+  void setContent(std::string&& content, const std::string& content_type)
+      override {
+    res_.set_content(std::move(content), content_type);
+  }
+
+ private:
+  httplib::Response& res_;
+};
+
+std::string jsonStrEscape(const std::string& str) {
+  std::ostringstream ostream;
+  for (char ch : str) {
+    if (ch == '"') {
+      ostream << "\\\"";
+    } else if (ch == '\\') {
+      ostream << "\\\\";
+    } else if (ch == '\b') {
+      ostream << "\\b";
+    } else if (ch == '\f') {
+      ostream << "\\f";
+    } else if (ch == '\n') {
+      ostream << "\\n";
+    } else if (ch == '\r') {
+      ostream << "\\r";
+    } else if (ch == '\t') {
+      ostream << "\\t";
+    } else if ('\x00' <= ch && ch <= '\x1f') {
+      ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
+              << static_cast<int>(ch);
+    } else {
+      ostream << ch;
+    }
+  }
+  return ostream.str();
+}
+} // namespace
+
+WorkerServer::WorkerServer(const std::string& socketFile) {
+  // using unix sockets
+  server_.set_address_family(AF_UNIX);
+
+  // adjust keep alives as it stops the server from shutting down quickly
+  server_.set_keep_alive_timeout(1); // second, default is 5
+  server_.set_keep_alive_max_count(
+      30); // wait max 30 seconds before closing socket
+
+  server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
+    res.set_content(
+        R"BODY(<h1>torch.distributed.WorkerServer</h1>
+<a href="/handler/">Handler names</a>
+)BODY",
+        "text/html");
+  });
+  server_.Get(
+      "/handler/", [](const httplib::Request& req, httplib::Response& res) {
+        std::ostringstream body;
+        body << "[";
+        bool first = true;
+        for (const auto& name : getHandlerNames()) {
+          if (!first) {
+            body << ",";
+          }
+          first = false;
+
+          body << "\"" << jsonStrEscape(name) << "\"";
+        }
+        body << "]";
+
+        res.set_content(body.str(), "application/json");
+      });
+  server_.Post(
+      "/handler/:handler",
+      [](const httplib::Request& req, httplib::Response& res) {
+        auto handler_name = req.path_params.at("handler");
+        HandlerFunc handler;
+        try {
+          handler = getHandler(handler_name);
+        } catch (const std::exception& e) {
+          res.status = 404;
+          res.set_content(
+              fmt::format("Handler {} not found: {}", handler_name, e.what()),
+              "text/plain");
+          return;
+        }
+        RequestImpl torchReq{req};
+        ResponseImpl torchRes{res};
+
+        try {
+          handler(torchReq, torchRes);
+        } catch (const std::exception& e) {
+          res.status = 500;
+          res.set_content(
+              fmt::format("Handler {} failed: {}", handler_name, e.what()),
+              "text/plain");
+          return;
+        } catch (...) {
+          res.status = 500;
+          res.set_content(
+              fmt::format(
+                  "Handler {} failed with unknown exception", handler_name),
+              "text/plain");
+          return;
+        }
+      });
+
+  if (std::filesystem::exists(socketFile)) {
+    throw std::runtime_error(fmt::format("{} already exists", socketFile));
+  }
+
+  C10D_WARNING("Server listening on {}", socketFile);
+  if (!server_.bind_to_port(socketFile, 80)) {
+    throw std::runtime_error(fmt::format("Error binding to {}", socketFile));
+  }
+
+  serverThread_ = std::thread([this]() {
+    try {
+      if (!server_.listen_after_bind()) {
+        throw std::runtime_error("failed to listen");
+      }
+    } catch (std::exception& e) {
+      C10D_ERROR("Error while running server: {}", e.what());
+      throw;
+    }
+    C10D_WARNING("Server exited");
+  });
+}
+
+void WorkerServer::shutdown() {
+  C10D_WARNING("Server shutting down");
+  server_.stop();
+  serverThread_.join();
+}
+
+WorkerServer::~WorkerServer() {
+  if (serverThread_.joinable()) {
+    C10D_WARNING("WorkerServer destructor called without shutdown");
+    shutdown();
+  }
+}
+
+} // namespace control_plane
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
new file mode 100644
index 00000000000..7d64038f0b0
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <thread>
+
+#include <httplib.h>
+
+#include <c10/util/intrusive_ptr.h>
+#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
+
+namespace c10d {
+namespace control_plane {
+
+class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
+ public:
+  WorkerServer(const std::string& socketFile);
+  ~WorkerServer();
+
+  void shutdown();
+
+ private:
+  httplib::Server server_;
+  std::thread serverThread_;
+};
+
+} // namespace control_plane
+} // namespace c10d
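The server can also be reached with the standard library alone, with no urllib3 dependency — a hedged sketch, not part of this diff, with an illustrative socket path::

    # sketch: http.client over an AF_UNIX socket
    import http.client
    import socket

    class StdlibUnixHTTPConnection(http.client.HTTPConnection):
        def __init__(self, socket_path: str) -> None:
            super().__init__("localhost")
            self._socket_path = socket_path

        def connect(self) -> None:
            # replace the TCP connect with a Unix-domain connect
            self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.sock.connect(self._socket_path)

    conn = StdlibUnixHTTPConnection("/tmp/worker.sock")
    conn.request("GET", "/")
    print(conn.getresponse().read())  # the HTML index served by WorkerServer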
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 6f6dae32606..c4b9a9823c8 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
 #include
 #ifndef _WIN32
 #include
@@ -3164,6 +3165,17 @@ such as `dist.all_reduce(tensor, async_op=True)`.
         return py::bytes(::c10d::dump_nccl_trace());
       });
 #endif
+
+  intrusive_ptr_class_<::c10d::control_plane::WorkerServer>(
+      module, "_WorkerServer", R"(
+)")
+      .def(
+          py::init([](const std::string& socketPath) {
+            return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
+                socketPath);
+          }),
+          py::arg("socket_path"))
+      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
   Py_RETURN_TRUE;
 }
diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py
new file mode 100644
index 00000000000..16038363786
--- /dev/null
+++ b/torch/distributed/elastic/control_plane.py
@@ -0,0 +1,51 @@
+import os
+from contextlib import contextmanager, ExitStack
+from typing import Generator
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+__all__ = [
+    "worker_main",
+]
+
+TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
+
+
+@contextmanager
+def _worker_server(socket_path: str) -> Generator[None, None, None]:
+    from torch._C._distributed_c10d import _WorkerServer
+
+    server = _WorkerServer(socket_path)
+    try:
+        yield
+    finally:
+        server.shutdown()
+
+
+@contextmanager
+@record
+def worker_main() -> Generator[None, None, None]:
+    """
+    This is a context manager that wraps your main entry function. It combines
+    the existing ``errors.record`` logic with a new ``_WorkerServer`` that
+    exposes handlers via a Unix socket specified by the
+    ``TORCH_WORKER_SERVER_SOCKET`` environment variable.
+
+    Example
+
+    ::
+
+        @worker_main()
+        def main():
+            pass
+
+        if __name__ == "__main__":
+            main()
+
+    """
+    with ExitStack() as stack:
+        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
+        if socket_path is not None:
+            stack.enter_context(_worker_server(socket_path))
+
+        yield
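Putting the pieces together, the ``dump_nccl_trace_pickle`` handler registered in ProcessGroupNCCL.cpp can be scraped out-of-band while a job is running. A sketch, assuming NCCL is in use, the env var was set before process start, and the socket path below (names are illustrative)::

    # sketch: dump the NCCL flight recorder from a live worker
    import pickle

    pool = UnixHTTPConnectionPool("/tmp/worker.sock")  # helper from test_control_plane.py
    resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
    trace = pickle.loads(resp.data)  # unpickled flight-recorder entries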