mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53172 Pull Request resolved: https://github.com/pytorch/elastic/pull/141 Upstreams two modules to torch: 1. `torchelastic.rendezvous` 2. `torchelastic.utils` These modules were chosen as `[1/n]` since they are the leaf modules in torchelastic. ==== NOTES: ==== 1. I'm disabling etcd_rendezvous and etcd_server tests in CIRCLECI for the moment since I need to edit the test dockers to contain the etcd server binary (there are 4-5 test dockers, one for each platform, so this is going to take some time for me to set up the environments and test) - T85992919. 2. I've fixed all lint errors on python files but there are still some on the cpp files of the ZeusRendezvous. I took a look at them, and I don't want to fix the linter errors right now for 2 major reasons: (a) Some of them are more than formatting changes (e.g. std::move vs pass by value) and I don't want to introduce bundled changes with the move; (b) The old rendezvous code (the one we forked from in caffe2/fb) has the same problems and I think it's better for us to deal with this when we deprecate caffe2/fb/rendezvous in favor of the one in torchelastic - T86012579. Test Plan: ``` buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/test/... buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/data/test/... buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/test/... buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/fb/... buck test mode/dev-nosan //pytorch/elastic/torchelastic/... ``` + Sandcastle Reviewed By: H-Huang Differential Revision: D26718746 fbshipit-source-id: 67cc0350c3d847221cb3c3038f98f47915362f51
128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import multiprocessing as mp
|
|
import os
|
|
import socket
|
|
import unittest
|
|
from contextlib import closing
|
|
|
|
from torch.distributed.elastic.utils.distributed import (
|
|
create_c10d_store,
|
|
get_free_port,
|
|
get_socket_with_port,
|
|
)
|
|
|
|
|
|
def _create_c10d_store_mp(is_server, server_addr, port, world_size):
    """Process entry point: open a c10d store and publish a per-process key.

    Forwards its arguments to ``create_c10d_store`` with ``timeout=2`` and
    then writes ``test_key/<pid>`` (pid of the calling process) into the
    store, so that another participant can look the key up afterwards.

    Raises:
        AssertionError: if ``create_c10d_store`` returns ``None``.
    """
    c10d_store = create_c10d_store(
        is_server, server_addr, port, world_size, timeout=2
    )
    if c10d_store is None:
        raise AssertionError()

    key = f"test_key/{os.getpid()}"
    value = "test_value".encode("UTF-8")
    c10d_store.set(key, value)
|
|
|
|
|
|
class DistributedUtilTest(unittest.TestCase):
    """Tests for ``create_c10d_store`` and the port helpers it is used with."""

    def test_create_store_single_server(self):
        """A server created without an explicit port or world size returns a store."""
        store = create_c10d_store(is_server=True, server_addr=socket.gethostname())
        self.assertIsNotNone(store)

    def test_create_store_no_port_multi(self):
        """Omitting ``server_port`` while ``world_size=2`` must raise ``ValueError``."""
        with self.assertRaises(ValueError):
            create_c10d_store(
                is_server=True, server_addr=socket.gethostname(), world_size=2
            )

    def test_create_store_multi(self):
        """One in-process server plus two worker processes share a single store."""
        world_size = 3
        server_port = get_free_port()
        localhost = socket.gethostname()

        # Every rank except the server (this process) runs as a subprocess.
        workers = [
            mp.Process(
                target=_create_c10d_store_mp,
                args=(False, localhost, server_port, world_size),
            )
            for _ in range(world_size - 1)
        ]
        for worker in workers:
            worker.start()

        try:
            # start the server on the main process
            store = create_c10d_store(
                is_server=True,
                server_addr=localhost,
                server_port=server_port,
                world_size=world_size,
                timeout=2,
            )
        finally:
            # Always reap the workers, even when the server side failed, so
            # no child process outlives the test.
            for worker in workers:
                worker.join()

        # A worker that failed never wrote its key, so surface a non-zero
        # exit code before attempting to read that key back.
        for worker in workers:
            self.assertEqual(0, worker.exitcode)

        # check test_key/pid == "test_value"
        for worker in workers:
            self.assertEqual(
                "test_value", store.get(f"test_key/{worker.pid}").decode("UTF-8")
            )

    def test_create_store_timeout_on_server(self):
        """A server whose second participant never joins raises ``TimeoutError``."""
        # Acquire the port outside the assertRaises block so that only a
        # TimeoutError raised by create_c10d_store can satisfy the assertion.
        port = get_free_port()
        with self.assertRaises(TimeoutError):
            create_c10d_store(
                is_server=True,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_create_store_timeout_on_worker(self):
        """A worker pointed at a port with no server raises ``TimeoutError``."""
        # Acquire the port outside the assertRaises block so that only a
        # TimeoutError raised by create_c10d_store can satisfy the assertion.
        port = get_free_port()
        with self.assertRaises(TimeoutError):
            create_c10d_store(
                is_server=False,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_port_already_in_use_on_server(self):
        """Creating a server on a port that is still held raises ``IOError``."""
        sock = get_socket_with_port()
        with closing(sock):
            # try to create a store on the same port without releasing the
            # socket; this should raise an IOError
            port = sock.getsockname()[1]
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=True,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )

    def test_port_already_in_use_on_worker(self):
        """A worker targeting a held port with no store server raises ``IOError``."""
        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
            # on the worker a port conflict shouldn't matter; it should just
            # time out since we never created a server. TimeoutError is a
            # subclass of OSError (of which IOError is an alias), so
            # asserting IOError also accepts the timeout.
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=False,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )
|