pytorch/test/distributed/elastic/utils/distributed_test.py
Kiuk Chung ba75cedfc5 [1/n][torch/elastic][upstream] Move torchelastic/rendezvous to torch/distributed/rendezvous (#53172)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53172

Pull Request resolved: https://github.com/pytorch/elastic/pull/141

Upstreams two modules to torch:

1. `torchelastic.rendezvous`
2. `torchelastic.utils`

These modules were chosen as `[1/n]` since they are the leaf modules in torchelastic.

==== NOTES: ====
1. I'm disabling the etcd_rendezvous and etcd_server tests in CIRCLECI for the moment, since I need to edit the test dockers to contain the etcd server binary (there are 4-5 test dockers - one for each platform - so it is going to take some time for me to set up the environments and test) - T85992919.

2. I've fixed all lint errors in the Python files, but some remain in the C++ files of ZeusRendezvous. I took a look at them, and I don't want to fix those linter errors right now for 2 major reasons:
     1. Some of them are more than formatting changes (e.g. std::move vs pass by value) and I don't want to introduce bundled changes with the move
     2. The old rendezvous code (the one we forked from in caffe2/fb) has the same problems, and I think it's better for us to deal with this when we deprecate caffe2/fb/rendezvous in favor of the one in torchelastic - T86012579.

Test Plan:
```
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/data/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/fb/...
buck test mode/dev-nosan //pytorch/elastic/torchelastic/...
```
\+ Sandcastle

Reviewed By: H-Huang

Differential Revision: D26718746

fbshipit-source-id: 67cc0350c3d847221cb3c3038f98f47915362f51
2021-03-05 11:27:57 -08:00

128 lines
4.0 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import os
import socket
import unittest
from contextlib import closing
from torch.distributed.elastic.utils.distributed import (
create_c10d_store,
get_free_port,
get_socket_with_port,
)
def _create_c10d_store_mp(is_server, server_addr, port, world_size):
    """Process entry point: open a c10d store and publish a per-process key.

    Forwards its arguments to ``create_c10d_store`` with a 2 second timeout,
    then writes ``test_key/<pid>`` -> ``b"test_value"`` (UTF-8 encoded) so the
    caller can verify that this process reached the store.

    Raises:
        AssertionError: if ``create_c10d_store`` returned ``None``.
    """
    c10d_store = create_c10d_store(
        is_server,
        server_addr,
        port,
        world_size,
        timeout=2,
    )
    if c10d_store is None:
        raise AssertionError()

    # The key embeds this process's pid so each process writes a distinct entry.
    key = f"test_key/{os.getpid()}"
    payload = "test_value".encode("UTF-8")
    c10d_store.set(key, payload)
class DistributedUtilTest(unittest.TestCase):
    """Tests for ``create_c10d_store`` and the free-port helpers."""

    def test_create_store_single_server(self):
        """A server-only store (default port and world size) is created."""
        store = create_c10d_store(is_server=True, server_addr=socket.gethostname())
        self.assertIsNotNone(store)

    def test_create_store_no_port_multi(self):
        """Omitting ``server_port`` with ``world_size=2`` raises ``ValueError``."""
        with self.assertRaises(ValueError):
            create_c10d_store(
                is_server=True, server_addr=socket.gethostname(), world_size=2
            )

    def test_create_store_multi(self):
        """One server (this process) plus two worker processes share a store.

        Each worker writes ``test_key/<its pid>``; the server then reads both
        keys back and checks that both workers exited with code 0.
        """
        world_size = 3
        server_port = get_free_port()
        localhost = socket.gethostname()

        # world_size counts the server too, so spawn one process per
        # remaining rank.
        workers = [
            mp.Process(
                target=_create_c10d_store_mp,
                args=(False, localhost, server_port, world_size),
            )
            for _ in range(world_size - 1)
        ]
        for worker in workers:
            worker.start()

        try:
            # start the server on the main process
            store = create_c10d_store(
                is_server=True,
                server_addr=localhost,
                server_port=server_port,
                world_size=world_size,
                timeout=2,
            )
            for worker in workers:
                worker.join()
        finally:
            # Always reap the children, even when server creation raised, so
            # a failing test does not leave stray processes behind.
            for worker in workers:
                if worker.is_alive():
                    worker.terminate()
                worker.join()

        # Check exit codes before touching the store: a crashed worker never
        # wrote its key, and reporting the exit code is a clearer failure
        # than an error from the store lookup.
        for worker in workers:
            self.assertEqual(
                0, worker.exitcode, f"worker pid={worker.pid} exited abnormally"
            )

        # check test_key/pid == "test_value"
        for worker in workers:
            self.assertEqual(
                "test_value", store.get(f"test_key/{worker.pid}").decode("UTF-8")
            )

    def test_create_store_timeout_on_server(self):
        """A server expecting 2 ranks with no workers raises ``TimeoutError``."""
        # Acquire the port outside assertRaises so that only the store
        # creation itself is allowed to satisfy the expected exception.
        port = get_free_port()
        with self.assertRaises(TimeoutError):
            create_c10d_store(
                is_server=True,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_create_store_timeout_on_worker(self):
        """A worker with no server to connect to raises ``TimeoutError``."""
        # Acquire the port outside assertRaises so that only the store
        # creation itself is allowed to satisfy the expected exception.
        port = get_free_port()
        with self.assertRaises(TimeoutError):
            create_c10d_store(
                is_server=False,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_port_already_in_use_on_server(self):
        """Creating a server store on an occupied port raises ``IOError``."""
        sock = get_socket_with_port()
        with closing(sock):
            # try to create a store on the same port without releasing the
            # socket; this should raise an IOError
            port = sock.getsockname()[1]
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=True,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )

    def test_port_already_in_use_on_worker(self):
        """A worker pointed at a port with no store server raises ``IOError``."""
        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
            # On the worker a port conflict shouldn't matter; the call is
            # expected to time out since we never created a server.
            # ``IOError`` is an alias of ``OSError`` in Python 3 and
            # ``TimeoutError`` is one of its subclasses, so a timeout
            # satisfies this assertion.
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=False,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )