pytorch/test/distributed/elastic/utils/distributed_test.py
Kiuk Chung 5a2f41a2db [torch/distributed.elastic] Fix utils.distributed_test.test_create_store_timeout_on_server to be dual-stack ip compatible (#60558)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60558

Fixes 1/2 flaky tests as described in: https://github.com/pytorch/pytorch/issues/60260

`test_create_store_timeout_on_server` tests whether trying to create a `c10d::TCPStore` server on an already taken port actually fails with an `IOError`. Prior to this change the `utils.get_socket_with_port()` util method was used to synthetically reserve a port, then try creating the `TCPStore` on that port to validate the `IOError`. The issue with this is that on a dual stack ip setup, `get_socket_with_port()` (since it uses `socket.AF_UNSPEC`) reserves an ipv6 port, while `TCPStore` will try binding to an ipv4 port, so an `IOError` is not observed.

This change rewrites the test to create two `TCPStore` servers. The first chooses a free port (by passing `server_port=0`), while the second tries to create a `TCPStore` server on the port that the first store is already running on. This induces an `IOError` in the second store's constructor.

NOTE: this change does not solve a broader issue with `TCPStore`, where the server and workers may end up listening and connecting on mismatched address families (ipv4 vs ipv6) when running on dual-stack ip hosts that lack an ipv4 DNS entry and/or a `/etc/gai.conf` specifying the preferred bind ordering. See: https://github.com/pytorch/pytorch/pull/49124

Test Plan:
```
buck test //caffe2/test/distributed/elastic/utils:distributed_test
```

Reviewed By: cbalioglu

Differential Revision: D29334947

fbshipit-source-id: 76b998c59082cb04c0e86b7a1f3b509367fa0136
2021-06-23 17:12:18 -07:00

138 lines
4.4 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import os
import socket
import unittest
from contextlib import closing
from torch.distributed.elastic.utils.distributed import (
create_c10d_store,
get_free_port,
get_socket_with_port,
)
from torch.testing._internal.common_utils import IS_MACOS, IS_WINDOWS, run_tests
def _create_c10d_store_mp(is_server, server_addr, port, world_size):
    """Child-process entry point that joins a c10d store and writes one key.

    Calls ``create_c10d_store`` for ``server_addr:port`` with a 2 second
    timeout, then stores ``test_key/<this pid>`` -> ``"test_value"`` (UTF-8
    encoded) so the parent process can check that this participant joined.

    Raises:
        AssertionError: if ``create_c10d_store`` returns ``None``.
    """
    c10d_store = create_c10d_store(
        is_server, server_addr, port, world_size, timeout=2
    )
    if c10d_store is not None:
        # key is namespaced by pid so each child writes a distinct entry
        pid_key = f"test_key/{os.getpid()}"
        c10d_store.set(pid_key, "test_value".encode("UTF-8"))
        return
    raise AssertionError()
@unittest.skipIf(IS_WINDOWS or IS_MACOS, "tests incompatible with tsan or asan")
class DistributedUtilTest(unittest.TestCase):
    """Tests for ``create_c10d_store``: multi-process setup, timeouts, and
    port conflicts on both the server and worker side.

    NOTE(review): the skip condition is Windows/macOS while the skip reason
    mentions tsan/asan -- confirm which one is the intended justification.
    """

    def test_create_store_single_server(self):
        # creating a lone server with default arguments should succeed
        store = create_c10d_store(is_server=True, server_addr=socket.gethostname())
        self.assertIsNotNone(store)

    def test_create_store_no_port_multi(self):
        # world_size > 1 without an explicit server_port is expected to be
        # rejected up front
        with self.assertRaises(ValueError):
            create_c10d_store(
                is_server=True, server_addr=socket.gethostname(), world_size=2
            )

    def test_create_store_multi(self):
        world_size = 3
        server_port = get_free_port()
        localhost = socket.gethostname()
        # (world_size - 1) worker processes plus the server in this process
        workers = [
            mp.Process(
                target=_create_c10d_store_mp,
                args=(False, localhost, server_port, world_size),
            )
            for _ in range(world_size - 1)
        ]
        for worker in workers:
            worker.start()
        try:
            # start the server on the main process
            store = create_c10d_store(
                is_server=True,
                server_addr=localhost,
                server_port=server_port,
                world_size=world_size,
                timeout=2,
            )
        finally:
            # always reap the workers, even if the server failed to come up,
            # so no child process outlives the test
            for worker in workers:
                worker.join()

        # check exit codes first so a crashed worker is reported directly
        # rather than surfacing indirectly as a missing key
        for worker in workers:
            self.assertEqual(0, worker.exitcode)
        # check test_key/pid == "test_value"
        for worker in workers:
            self.assertEqual(
                "test_value", store.get(f"test_key/{worker.pid}").decode("UTF-8")
            )

    def test_create_store_timeout_on_server(self):
        # the server expects 2 members but no worker ever joins
        port = get_free_port()
        with self.assertRaises(TimeoutError):
            create_c10d_store(
                is_server=True,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_create_store_timeout_on_worker(self):
        # the worker connects to a port where no server was ever started
        port = get_free_port()
        with self.assertRaises(TimeoutError):
            create_c10d_store(
                is_server=False,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_port_already_in_use_on_server(self):
        # try to create the TCPStore server twice on the same port;
        # the second attempt should fail due to the port conflict
        server_addr = socket.gethostname()
        # server_port=0 lets the first store bind onto any free port
        pick_free_port = 0
        store1 = create_c10d_store(
            is_server=True,
            server_addr=server_addr,
            server_port=pick_free_port,
            timeout=1,
        )
        # try creating the second store on the port the first store is bound to
        with self.assertRaises(IOError):
            create_c10d_store(
                is_server=True,
                server_addr=server_addr,
                server_port=store1.port,
                timeout=1,
            )

    def test_port_already_in_use_on_worker(self):
        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
            # on the worker a port conflict shouldn't matter; it should just
            # time out since we never created a server. TimeoutError is a
            # subclass of OSError (which IOError aliases), so asserting IOError
            # also accepts a timeout.
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=False,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )
# run the test suite via torch's shared test runner when executed directly
if __name__ == "__main__":
    run_tests()