# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import time
from datetime import timedelta
from sys import platform

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.testing._internal.common_utils as common

try:
    from torch._six import string_classes
except ImportError:
    # NOTE(review): torch._six is a private module that is absent from newer
    # PyTorch releases; fall back to the tuple it used to export.
    string_classes = (str, bytes)

from torch.testing._internal.common_distributed import (
    skip_if_win32,
    create_tcp_store,
)
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    run_tests,
    retry_on_connect_failures,
    ADDRESS_IN_USE,
    CONNECT_TIMEOUT,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

DEFAULT_HOSTNAME = "localhost"

torch.backends.cuda.matmul.allow_tf32 = False


def gpus_for_rank(world_size):
    """Partition the visible GPUs into equally sized per-rank subsets.

    Multigpu tests are designed to simulate the multi nodes with multi
    GPUs on each node. Nccl backend requires equal #GPUs in each process.
    On a single node, all visible GPUs are evenly divided to subsets, each
    process only uses a subset.

    Args:
        world_size: number of processes to split the visible devices across.

    Returns:
        A list with ``world_size`` entries; entry ``rank`` is the list of
        device indices assigned to that rank. Devices left over after the
        integer division are not assigned to any rank.
    """
    device_count = torch.cuda.device_count()
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    return [
        visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]


class StoreTestBase:
    """Mixin with store tests shared by the concrete store test cases.

    Subclasses provide ``_create_store()`` and may override
    ``num_keys_total`` when their store holds extra bookkeeping keys.
    """

    def _create_store(self, i=None):
        """Return the store under test; must be overridden by subclasses.

        ``i`` is unused here and optional: every caller in this file invokes
        ``_create_store()`` without arguments.
        """
        raise RuntimeError("not implemented")

    def _test_set_get(self, fs):
        """Exercise add/set/get/delete_key/num_keys on the store ``fs``."""
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))

        # A temporary key must not change the key count once it is deleted.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get(self):
        self._test_set_get(self._create_store())

    def _test_compare_set(self, store):
        """Exercise compare_set for missing, mismatching and matching values."""
        # Missing key with a non-empty expected value: the expected value
        # is handed back.
        missing_key_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))

        # Existing key, wrong expected value: nothing is written.
        old_value_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))

        # Existing key, correct expected value: the new value is stored.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))

        # Missing key with an empty expected value: the new value is stored.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5


class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.FileStore``."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # Release the Python-side handle when the test finishes.
        self.addCleanup(self.file.close)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store


@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.HashStore``."""

    def _create_store(self):
        store = dist.HashStore()
        store.set_timeout(timedelta(seconds=300))
        return store


class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a PrefixStore over a FileStore."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # Release the Python-side handle when the test finishes.
        self.addCleanup(self.file.close)
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)


class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests plus TCPStore-specific tests."""

    def _create_store(self):
        store = create_tcp_store()
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_address_already_in_use(self):
        """A second server store on the same address/port must fail."""
        err_msg_reg = "^The server socket has failed to listen on any local "
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        """Two multi-tenant server stores may share one address/port."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        """Process group and RPC can both initialize on one socket address."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )
        try:
            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}"
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )

            rpc.shutdown()
        finally:
            # Always tear down the default process group so that a retry of
            # this test (or any later test) can initialize it again.
            dist.destroy_process_group()

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Check that num_keys tracks set/add/delete_key on the store ``fs``."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Shorten the timeout so the get() on the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and verify get/set/compare_set through it."""
        client_store = dist.TCPStore(
            addr, port, world_size=world_size, timeout=timedelta(seconds=10)
        )
        self.assertEqual(b"value", client_store.get("key"))

        key = f"new_key{index}"
        client_store.set(key, f"new_value{index}")
        swapped = client_store.compare_set(
            key, f"new_value{index}", f"next_value{index}"
        )
        self.assertEqual(f"next_value{index}".encode(), swapped)

    def _multi_worker_helper(self, world_size):
        """Start one server store and connect clients to it one after another.

        ``world_size`` may be ``None``; in that case a single client is
        created.
        """
        addr = DEFAULT_HOSTNAME
        server_store = create_tcp_store(addr, world_size, wait_for_workers=False)
        server_store.set("key", "value")
        port = server_store.port

        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)


class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a PrefixStore over a TCPStore."""

    def setUp(self):
        super().setUp()
        self.tcpstore = create_tcp_store()
        self.prefix = "test_prefix"
        self.tcpstore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
    # added by the user and one additional key used to coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6


class MyPythonStore(dist.Store):
    """Minimal dict-backed ``dist.Store`` implemented in Python."""

    def __init__(self):
        super().__init__()
        self.store = {}

    def set(self, key, value):
        """Store ``value`` (bytes) under ``key`` (string)."""
        if not isinstance(key, string_classes):
            raise AssertionError("Expected set to be called with string key")
        if not isinstance(value, bytes):
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        """Return the bytes stored under ``key``, or ``b""`` if it is absent."""
        value = self.store.get(key, b"")
        if not isinstance(value, bytes):
            raise AssertionError("Expected get to return bytes value")
        return value

    def add(self, key, value):
        """Add ``value`` to the integer stored under ``key`` and return the sum.

        A missing key counts as 0. The result is stored as its decimal
        representation encoded as UTF-8 bytes.
        """
        new = int(self.store.get(key, 0)) + value
        self.set(key, str(new).encode("utf-8"))
        return new


class PythonStoreTest(TestCase):
    """Checks a Python-defined store when driven from C++."""

    def test_set_get(self):
        # If we were to inherit from StoreTestBase and try to use
        # its test_set_get function, we would exercise the Python
        # API directly, instead of going through the C++ trampoline.
        # We care about testing the C++ trampoline, so run the
        # equivalent of StoreTestBase.test_set_get from C++.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        dist._test_python_store(MyPythonStore())


class RendezvousTest(TestCase):
    """Tests for rendezvous URL dispatch."""

    def test_unknown_handler(self):
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous("invalid://")


class RendezvousEnvTest(TestCase):
    """Tests for the ``env://`` rendezvous handler."""

    @retry_on_connect_failures
    def test_nominal(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())

        # Single rank
        os.environ["RANK"] = "0"
        gen0 = dist.rendezvous("env://")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))


class RendezvousFileTest(TestCase):
    """Tests for the ``file://`` rendezvous handler."""

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "path missing"):
            gen = dist.rendezvous("file://?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("file:///tmp/foo?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("file:///tmp/foo?rank=0")
            next(gen)

    def test_nominal(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
            gen0 = dist.rendezvous(url + "&rank=0")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)
            gen1 = dist.rendezvous(url + "&rank=1")
            store1, rank1, size1 = next(gen1)
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Set value on both stores
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # Cross check with get
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))


@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Tests for the ``tcp://`` rendezvous handler."""

    def create_tcp_url(self):
        """Return a ``tcp://`` rendezvous URL on a free port with world_size=1."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        return f"tcp://{addr}:{port}?world_size=1"

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "port number missing"):
            gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
            next(gen)

    def test_dns_timeout(self):
        with self.assertRaisesRegex(
            TimeoutError, "client socket has timed out after.*dnsnotexist"
        ):
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)

    @retry_on_connect_failures
    def test_nominal(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
        store0, rank0, size0 = next(gen0)
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        # A monotonic clock is used so wall-clock adjustments cannot skew the
        # measured duration.
        start = time.monotonic()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistant key")

        time_diff = time.monotonic() - start
        # The bound is deliberately loose: 10x the configured timeout.
        self.assertGreater(test_store_timeout.seconds * 10, time_diff)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()