WorkerServer: add support for binding to TCP (#127986)

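This adds support for binding the WorkerServer to a TCP host and port, in addition to the existing Unix socket support. Passing a port selects TCP; omitting it (the default is `port = -1`) keeps the Unix socket behavior.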

```py
server = _WorkerServer("", 1234)
```

Test plan:

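Added a unit test (`test_tcp`) that starts the server on TCP port 1234 and checks that `GET /handler/` returns HTTP 200.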

```
python test/distributed/elastic/test_control_plane.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127986
Approved by: https://github.com/c-p-i-o
This commit is contained in:
Tristan Rice 2024-06-05 22:56:29 +00:00 committed by PyTorch MergeBot
parent a7c596870d
commit ffaea656b5
4 changed files with 39 additions and 19 deletions

View File

@ -81,6 +81,17 @@ class WorkerServerTest(TestCase):
self.assertEqual(resp.status, 200)
out = pickle.loads(resp.data)
def test_tcp(self) -> None:
    """Check that ``_WorkerServer`` answers HTTP requests when bound to TCP.

    Starts a server with an empty host string on port 1234, requests
    ``/handler/`` through ``localhost`` and expects an HTTP 200 response.
    """
    import requests

    from torch._C._distributed_c10d import _WorkerServer

    server = _WorkerServer("", 1234)
    try:
        out = requests.get("http://localhost:1234/handler/")
        self.assertEqual(out.status_code, 200)
    finally:
        # Always stop the server: if the request raises or the assertion
        # fails, skipping shutdown() would leave the server running (and
        # presumably port 1234 still bound) for the rest of the test process.
        server.shutdown()
if __name__ == "__main__":
run_tests()

View File

@ -71,15 +71,7 @@ std::string jsonStrEscape(const std::string& str) {
}
} // namespace
WorkerServer::WorkerServer(const std::string& socketFile) {
// using unix sockets
server_.set_address_family(AF_UNIX);
// adjust keep alives as it stops the server from shutting down quickly
server_.set_keep_alive_timeout(1); // second, default is 5
server_.set_keep_alive_max_count(
30); // wait max 30 seconds before closing socket
WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
res.set_content(
R"BODY(<h1>torch.distributed.WorkerServer</h1>
@ -139,13 +131,29 @@ WorkerServer::WorkerServer(const std::string& socketFile) {
}
});
if (std::filesystem::exists(socketFile)) {
throw std::runtime_error(fmt::format("{} already exists", socketFile));
}
// adjust keep alives as it stops the server from shutting down quickly
server_.set_keep_alive_timeout(1); // second, default is 5
server_.set_keep_alive_max_count(
30); // wait max 30 seconds before closing socket
C10D_WARNING("Server listening to {}", socketFile);
if (!server_.bind_to_port(socketFile, 80)) {
throw std::runtime_error(fmt::format("Error binding to {}", socketFile));
if (port == -1) {
// using unix sockets
server_.set_address_family(AF_UNIX);
if (std::filesystem::exists(hostOrFile)) {
throw std::runtime_error(fmt::format("{} already exists", hostOrFile));
}
C10D_WARNING("Server listening to UNIX {}", hostOrFile);
if (!server_.bind_to_port(hostOrFile, 80)) {
throw std::runtime_error(fmt::format("Error binding to {}", hostOrFile));
}
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
if (!server_.bind_to_port(hostOrFile, port)) {
throw std::runtime_error(
fmt::format("Error binding to {}:{}", hostOrFile, port));
}
}
serverThread_ = std::thread([this]() {

View File

@ -14,7 +14,7 @@ namespace control_plane {
class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
public:
WorkerServer(const std::string& socketFile);
WorkerServer(const std::string& hostOrFile, int port = -1);
~WorkerServer();
void shutdown();

View File

@ -3170,11 +3170,12 @@ such as `dist.all_reduce(tensor, async_op=True)`.
module, "_WorkerServer", R"(
)")
.def(
py::init([](const std::string& socketPath) {
py::init([](const std::string& hostOrFile, int port) {
return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
socketPath);
hostOrFile, port);
}),
py::arg("socket_path"))
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
Py_RETURN_TRUE;
}