mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
WorkerServer: add support for binding to TCP (#127986)
This adds support for binding the WorkerServer to a TCP host and port, in addition to the existing Unix socket support.
```py
server = _WorkerServer("", 1234)
```
Test plan:
Added unit test
```
python test/distributed/elastic/test_control_plane.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127986
Approved by: https://github.com/c-p-i-o
This commit is contained in:
parent
a7c596870d
commit
ffaea656b5
|
|
@ -81,6 +81,17 @@ class WorkerServerTest(TestCase):
|
|||
self.assertEqual(resp.status, 200)
|
||||
out = pickle.loads(resp.data)
|
||||
|
||||
def test_tcp(self) -> None:
    """Start a ``_WorkerServer`` on TCP port 1234 and check it answers HTTP.

    The server is created with an empty host string and port 1234, then a
    GET request to ``/handler/`` on localhost must return status 200.
    """
    import requests

    from torch._C._distributed_c10d import _WorkerServer

    server = _WorkerServer("", 1234)
    try:
        out = requests.get("http://localhost:1234/handler/")
        self.assertEqual(out.status_code, 200)
    finally:
        # Run the explicit shutdown even when the request or the assertion
        # raises, so a failing check does not skip it.
        server.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -71,15 +71,7 @@ std::string jsonStrEscape(const std::string& str) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
WorkerServer::WorkerServer(const std::string& socketFile) {
|
||||
// using unix sockets
|
||||
server_.set_address_family(AF_UNIX);
|
||||
|
||||
// adjust keep alives as it stops the server from shutting down quickly
|
||||
server_.set_keep_alive_timeout(1); // second, default is 5
|
||||
server_.set_keep_alive_max_count(
|
||||
30); // wait max 30 seconds before closing socket
|
||||
|
||||
WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
|
||||
server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
|
||||
res.set_content(
|
||||
R"BODY(<h1>torch.distributed.WorkerServer</h1>
|
||||
|
|
@ -139,13 +131,29 @@ WorkerServer::WorkerServer(const std::string& socketFile) {
|
|||
}
|
||||
});
|
||||
|
||||
if (std::filesystem::exists(socketFile)) {
|
||||
throw std::runtime_error(fmt::format("{} already exists", socketFile));
|
||||
}
|
||||
// adjust keep alives as it stops the server from shutting down quickly
|
||||
server_.set_keep_alive_timeout(1); // second, default is 5
|
||||
server_.set_keep_alive_max_count(
|
||||
30); // wait max 30 seconds before closing socket
|
||||
|
||||
C10D_WARNING("Server listening to {}", socketFile);
|
||||
if (!server_.bind_to_port(socketFile, 80)) {
|
||||
throw std::runtime_error(fmt::format("Error binding to {}", socketFile));
|
||||
if (port == -1) {
|
||||
// using unix sockets
|
||||
server_.set_address_family(AF_UNIX);
|
||||
|
||||
if (std::filesystem::exists(hostOrFile)) {
|
||||
throw std::runtime_error(fmt::format("{} already exists", hostOrFile));
|
||||
}
|
||||
|
||||
C10D_WARNING("Server listening to UNIX {}", hostOrFile);
|
||||
if (!server_.bind_to_port(hostOrFile, 80)) {
|
||||
throw std::runtime_error(fmt::format("Error binding to {}", hostOrFile));
|
||||
}
|
||||
} else {
|
||||
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
|
||||
if (!server_.bind_to_port(hostOrFile, port)) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("Error binding to {}:{}", hostOrFile, port));
|
||||
}
|
||||
}
|
||||
|
||||
serverThread_ = std::thread([this]() {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ namespace control_plane {
|
|||
|
||||
class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
WorkerServer(const std::string& socketFile);
|
||||
WorkerServer(const std::string& hostOrFile, int port = -1);
|
||||
~WorkerServer();
|
||||
|
||||
void shutdown();
|
||||
|
|
|
|||
|
|
@ -3170,11 +3170,12 @@ such as `dist.all_reduce(tensor, async_op=True)`.
|
|||
module, "_WorkerServer", R"(
|
||||
)")
|
||||
.def(
|
||||
py::init([](const std::string& socketPath) {
|
||||
py::init([](const std::string& hostOrFile, int port) {
|
||||
return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
|
||||
socketPath);
|
||||
hostOrFile, port);
|
||||
}),
|
||||
py::arg("socket_path"))
|
||||
py::arg("host_or_file"),
|
||||
py::arg("port") = -1)
|
||||
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user