pytorch/test/distributed/elastic/test_control_plane.py
Tristan Rice ffaea656b5 WorkerServer: add support for binding to TCP (#127986)
This adds support for binding the WorkerServer to a TCP port, in addition to the existing Unix socket support.

```py
server = _WorkerServer("", 1234)
```

Test plan:

Added unit test

```
python test/distributed/elastic/test_control_plane.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127986
Approved by: https://github.com/c-p-i-o
2024-06-05 22:56:32 +00:00

98 lines
2.8 KiB
Python

#!/usr/bin/env python3
# Owner(s): ["oncall: distributed"]
import json
import os
import pickle
import socket
import tempfile
from contextlib import contextmanager
from typing import Iterator

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool

from torch.distributed.elastic.control_plane import (
    TORCH_WORKER_SERVER_SOCKET,
    worker_main,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
class UnixHTTPConnection(HTTPConnection):
    """HTTP connection whose transport is a Unix domain socket.

    The base class is initialised with the placeholder host ``"localhost"``;
    the socket file given as ``socket_path`` is what ``connect`` dials.
    """

    def __init__(self, socket_path: str) -> None:
        """Remember ``socket_path`` and initialise the base connection."""
        self.socket_path = socket_path
        super().__init__("localhost")

    def connect(self) -> None:
        """Create an ``AF_UNIX`` stream socket and connect it to ``socket_path``."""
        unix_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Store the socket on the instance before connecting so the object
        # holds the same state whether or not connect() raises.
        self.sock = unix_sock
        unix_sock.connect(self.socket_path)
class UnixHTTPConnectionPool(HTTPConnectionPool):
    """Connection pool that hands out ``UnixHTTPConnection`` objects.

    The base pool is initialised with the placeholder host ``"localhost"``;
    every connection it creates targets the socket file at ``socket_path``.
    """

    def __init__(self, socket_path: str) -> None:
        """Initialise the base pool and remember ``socket_path``."""
        super().__init__("localhost")
        self.socket_path = socket_path

    def _new_conn(self):
        """Build a fresh connection bound to this pool's socket path."""
        connection = UnixHTTPConnection(self.socket_path)
        return connection
@contextmanager
def local_worker_server() -> Iterator["UnixHTTPConnectionPool"]:
    """Run ``worker_main`` against a temporary Unix socket and yield a client pool.

    A socket path inside a fresh temporary directory is exported through the
    ``TORCH_WORKER_SERVER_SOCKET`` environment variable before ``worker_main()``
    is entered (presumably that is where the server picks up its socket path;
    confirm against ``torch.distributed.elastic.control_plane``).

    Yields:
        A ``UnixHTTPConnectionPool`` connected to that socket path.

    On exit the pool is closed, ``worker_main()`` is exited, the environment
    variable is restored to its previous state, and the temporary directory
    is removed.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        socket_path = os.path.join(tmpdir, "socket.sock")

        # Remember the prior value so this helper does not leak a path to a
        # deleted temporary directory into whatever runs afterwards.
        previous = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path
        try:
            with worker_main():
                pool = UnixHTTPConnectionPool(socket_path)
                try:
                    yield pool
                finally:
                    # Release client connections before worker_main() exits.
                    pool.close()
        finally:
            if previous is None:
                os.environ.pop(TORCH_WORKER_SERVER_SOCKET, None)
            else:
                os.environ[TORCH_WORKER_SERVER_SOCKET] = previous
class WorkerServerTest(TestCase):
    """Tests for the worker control-plane HTTP server over Unix and TCP sockets."""

    def test_worker_server(self) -> None:
        """Check the index page, the ping handler, the handler list and a 404."""
        with local_worker_server() as pool:
            resp = pool.request("GET", "/")
            self.assertEqual(resp.status, 200)
            self.assertEqual(
                resp.data,
                b"""<h1>torch.distributed.WorkerServer</h1>
<a href="/handler/">Handler names</a>
""",
            )

            resp = pool.request("POST", "/handler/ping")
            self.assertEqual(resp.status, 200)
            self.assertEqual(resp.data, b"pong")

            resp = pool.request("GET", "/handler/")
            self.assertEqual(resp.status, 200)
            self.assertIn("ping", json.loads(resp.data))

            resp = pool.request("POST", "/handler/nonexistant")
            self.assertEqual(resp.status, 404)
            self.assertIn(b"Handler nonexistant not found:", resp.data)

    @requires_cuda
    def test_dump_nccl_trace_pickle(self) -> None:
        """Check that the NCCL trace handler returns an unpicklable-free payload."""
        with local_worker_server() as pool:
            resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
            self.assertEqual(resp.status, 200)
            # Only verifies that the body unpickles; the data comes from the
            # server started by this test, not from an external source.
            pickle.loads(resp.data)

    def test_tcp(self) -> None:
        """Check that a server bound to a TCP port answers the handler listing."""
        import requests

        from torch._C._distributed_c10d import _WorkerServer

        # NOTE(review): the port is fixed, so this test fails if 1234 is
        # already in use on the machine running it.
        port = 1234
        server = _WorkerServer("", port)
        try:
            out = requests.get(f"http://localhost:{port}/handler/")
            self.assertEqual(out.status_code, 200)
        finally:
            # Always shut down so a failed request or assertion does not
            # leave the server holding the port.
            server.shutdown()
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    run_tests()