During c10d process-group initialization, we wrap the TCPStore in multiple layers of PrefixStore, each of which adds a prefix to every key. One resulting key is "default_pg/0//cuda//timeout_dump". When the default PG is initialized without a store being passed in, we first add the prefix "default_pg" to the TCPStore returned from rendezvous (bdeaaad70c, torch/distributed/distributed_c10d.py, L1240), then the pg_name (i.e. "0", L1376), and then the device (i.e. "cuda", L1387). So when we call store_->set("timeout_dump"), the key actually written to the TCPStore is "default_pg/0//cuda//timeout_dump".

For sub-PGs, things get more involved: the store already wrapped with the default PG name is placed in a cache (L1517), and each sub-PG appends its own PG name on top of the cached store. Example keys are 'default_pg/0//10//cuda//timeout_dump', 'default_pg/0//12//cuda//timeout_dump', 'default_pg/0//38//cuda//timeout_dump' and 'default_pg/0//39//cuda//timeout_dump' (10, 12, 38 and 39 are the PG names of the sub-PGs created). These numbers grow so large because every sub-PG creation requires all ranks to call the API together, and the global counter used for PG names is bumped monotonically (L3666). The same applies when hashing is used to derive PG names.

This layering is a potential source of bugs: each sub-PG has its own ProcessGroupNCCL instance, and the added prefixes get in the way whenever we want to set something global to notify all sub-PGs (and all ranks). For example, if sub-PG 1 writes a value to the TCPStore under the key 'default_pg/0//1//cuda//timeout_dump' while the default PG instances check the TCPStore under the key 'default_pg/0//cuda//timeout_dump', the default PG instances will never see the signal.

So in this PR, we add a new API to PrefixStore that retrieves the innermost non-PrefixStore, so that set and check can operate on un-prefixed keys. The next PR will make the corresponding changes in the NCCL watchdog.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117074
Approved by: https://github.com/wconstab, https://github.com/H-Huang
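To make the layering concrete, here is a minimal standalone sketch (not part of the test file below) that reproduces the example key, with a HashStore standing in for the TCPStore. The exact prefix strings ("default_pg", "0/", "cuda/") are chosen here to match the example key and are not necessarily what c10d passes verbatim; the `_underlying_non_prefix_store` accessor is the API this PR adds and this file tests.

import torch.distributed as dist

base = dist.HashStore()
# Innermost wrap first ("default_pg"), then pg_name ("0/"), then device ("cuda/").
store = dist.PrefixStore("cuda/", dist.PrefixStore("0/", dist.PrefixStore("default_pg", base)))

store.set("timeout_dump", "1")
# Each PrefixStore layer joins "<prefix>/<key>", so the raw key in the base store is:
assert base.get("default_pg/0//cuda//timeout_dump") == b"1"

# The new accessor skips every PrefixStore layer, letting callers read and
# write un-prefixed keys that all PGs (and all ranks) can see.
assert store._underlying_non_prefix_store == base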
# Owner(s): ["oncall: distributed"]

import datetime
import os
import socket
import sys
import tempfile
import threading
import time
from datetime import timedelta
from sys import platform

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.distributed.rpc as rpc
from torch.distributed import DistNetworkError, DistError, DistStoreError
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.testing._internal.common_utils as common
from torch.testing._internal.common_distributed import (
    skip_if_win32,
    create_tcp_store,
    tp_transports,
)
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    run_tests,
    retry_on_connect_failures,
    ADDRESS_IN_USE,
    CONNECT_TIMEOUT,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

DEFAULT_HOSTNAME = "localhost"

torch.backends.cuda.matmul.allow_tf32 = False

def gpus_for_rank(world_size):
    """Multi-GPU tests are designed to simulate multiple nodes, each with
    multiple GPUs. The NCCL backend requires an equal number of GPUs in each
    process. On a single node, all visible GPUs are divided evenly into
    subsets, and each process uses only its own subset.
    """
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = torch.cuda.device_count() // world_size
    gpus_for_rank = []
    for rank in range(world_size):
        gpus_for_rank.append(
            visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]
        )
    return gpus_for_rank

class StoreTestBase:
    def _create_store(self):
        raise RuntimeError("not implemented")

    def _test_set_get_check(self, fs):
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        with self.assertRaisesRegex(RuntimeError, "[Tt]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get_check. Adding this as a
    # class property instead of hardcoding in the test since some Store
    # implementations will have a differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5

class FileStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc("worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options)

        # Init PG using file
        dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"file://{file.name}")
        dist.destroy_process_group()
        assert os.path.exists(file.name)

        rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

        del store
        assert os.path.exists(file.name)
        del store2
        assert not os.path.exists(file.name)

    @property
    def num_keys_total(self):
        return 6

@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    def _create_store(self):
        store = dist.HashStore()
        store.set_timeout(timedelta(seconds=300))
        return store

class PrefixStoreTest(TestCase):
    def setUp(self):
        # delete=False since the FileStore will clean up the file itself
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        tcp_store = dist.TCPStore(host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True)
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)

class PrefixFileStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)

    @property
    def num_keys_total(self):
        return 6

class TCPStoreTest(TestCase, StoreTestBase):
    def _create_store(self):
        store = create_tcp_store()
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(addr, world_size, wait_for_workers=False)

    def test_address_already_in_use(self):
        err_msg_reg = "^The server socket has failed to listen on any local "
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should
        # successfully initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )

        backend_opts = rpc.TensorPipeRpcBackendOptions(
            init_method=f"tcp://{addr}:{port}",
            _transports=tp_transports(),
        )
        rpc.init_rpc(
            name="worker0",
            rank=0,
            world_size=1,
            rpc_backend_options=backend_opts,
        )

        rpc.shutdown()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(addr, port, 1, is_master=True, master_listen_fd=listen_fd)

        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get_check: the 5 keys added by the
    # user and one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        client_store = dist.TCPStore(addr, port, world_size=world_size, timeout=timedelta(seconds=10))
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(f"next_value{index}".encode(),
                         client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))

    def _multi_worker_helper(self, world_size):
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        server_store.set("key", "value")
        port = server_store.port

        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    def test_append(self):
        store = self._create_store()
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(DistStoreError, r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined."):
            # world_size is 2 so it should time out
            dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2))

        # when wait_for_workers is not set, no exception should be raised
        dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2), wait_for_workers=False)

class LibUvTCPStoreTest(TCPStoreTest):

    def _create_store(self):
        store = create_tcp_store(use_libuv=True)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(addr, world_size, wait_for_workers=False, use_libuv=True)

class PrefixTCPStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super().setUp()
        self.tcpstore = create_tcp_store()
        self.prefix = "test_prefix"
        self.tcpstore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # The PrefixTCPStore has 6 keys in test_set_get_check: the 5 keys added
    # by the user and one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        store = self._create_store()
        wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store))
        self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)

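# A pure-Python Store implementation; PythonStoreTest below uses it to
# exercise the C++ trampoline rather than the Python API directly.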
class MyPythonStore(dist.Store):
    def __init__(self):
        super().__init__()
        self.store = {}

    def set(self, key, value):
        if not isinstance(key, (str, bytes)):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        value = self.store.get(key, b"")
        if type(value) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return value

    def add(self, key, value):
        new = int(self.store.get(key, 0)) + value
        self.set(key, bytes(str(new).encode("utf-8")))
        return new

    def compare_set(self, key, expected, newValue):
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        val = self.store.get(key, None)
        if expected == val or val is None:
            val = self.store[key] = newValue
        return val

class PythonStoreTest(TestCase):
    def test_set_get(self):
        # If we were to inherit from StoreTestBase and try to use
        # its test_set_get function, we would exercise the Python
        # API directly, instead of going through the C++ trampoline.
        # We care about testing the C++ trampoline, so run the
        # equivalent of StoreTestBase.test_set_get from C++.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        dist._test_python_store(MyPythonStore())

class RendezvousTest(TestCase):
    def test_unknown_handler(self):
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous("invalid://")

    def test_url_with_node_params(self):
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous("file://foo?rank=12&world_size=16", 12, 16)

class RendezvousEnvTest(TestCase):
    @retry_on_connect_failures
    def test_nominal(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())

        # Single rank
        os.environ["RANK"] = "0"
        gen0 = dist.rendezvous("env://")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

class RendezvousFileTest(TestCase):
    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "path missing"):
            gen = dist.rendezvous("file://?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("file:///tmp/foo?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("file:///tmp/foo?rank=0")
            next(gen)

    def test_nominal(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
            gen0 = dist.rendezvous(url + "&rank=0")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)
            gen1 = dist.rendezvous(url + "&rank=1")
            store1, rank1, size1 = next(gen1)
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Set value on both stores
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # Cross check with get
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))

@skip_if_win32()
class RendezvousTCPTest(TestCase):
    def create_tcp_url(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        url = "tcp://%s:%d?world_size=%d" % (addr, port, 1)
        return url

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "port number missing"):
            gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
            next(gen)

    def test_dns_timeout(self):
        with self.assertRaisesRegex(DistNetworkError, "client socket has timed out after.*dnsnotexist") as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        self.assertTrue(isinstance(manager.exception, DistError))

    @retry_on_connect_failures
    def test_nominal(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
        store0, rank0, size0 = next(gen0)
        # this should time out in 10s. If the timeout passed into rendezvous
        # was not respected, it will take much longer to time out.
        start = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistent key")

        end = time.time()
        time_diff = end - start
        self.assertGreater(test_store_timeout.seconds * 10, time_diff)

    def test_tcp_store_timeout_doesnt_break_client(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
        store0, rank0, size0 = next(gen0)
        # this should time out in 10s. If the timeout passed into rendezvous
        # was not respected, it will take much longer to time out.
        start = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("the_key")

        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        end = time.time()
        time_diff = end - start
        self.assertGreater(test_store_timeout.seconds * 10, time_diff)

    def test_tcp_store_url_with_libuv(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
        store0, rank0, size0 = next(gen0)
        self.assertTrue(store0.libuvBackend)

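# Records every extended-API call so tests can assert on the exact prefixed
# keys and values that a PrefixStore forwards to its underlying store.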
class DummyStore(dist.Store):
    def __init__(self):
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        self.multi_get_res = []
        super().__init__()

    def append(self, key, value):
        self.appends.append((key, value))

    def multi_get(self, keys):
        self.multi_gets.append(keys)
        return self.multi_get_res.pop(0)

    def multi_set(self, keys, values):
        self.multi_sets.append((keys, values))

    def has_extended_api(self):
        return True

class TestPythonStore(TestCase):
    def test_optional_methods_fail(self):
        class TestStore(dist.Store):
            pass
        store = TestStore()
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_passthrough(self):
        class TestStore(dist.Store):
            pass
        test_store = TestStore()
        store = dist.PrefixStore("p", test_store)
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        self.assertTrue(prefix.has_extended_api())

    def test_append_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        prefix.append("foo", "bar")
        self.assertEqual(1, len(store.appends))
        self.assertEqual(("p/foo", b"bar"), store.appends[0])

    def test_multi_get_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        store.multi_get_res.append([b"x", b"y"])
        res = prefix.multi_get(["foo", "bar"])
        self.assertEqual(1, len(store.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], store.multi_gets[0])
        self.assertEqual([b"x", b"y"], res)

    def test_multi_set_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        prefix.multi_set(["foo", "bar"], [b"x", b"y"])
        self.assertEqual(1, len(store.multi_sets))
        self.assertEqual(["p/foo", "p/bar"], store.multi_sets[0][0])
        self.assertEqual([b"x", b"y"], store.multi_sets[0][1])

    def test_extended_methods_fallbacks(self):
        test_store = MyPythonStore()
        store = dist.PrefixStore("p", test_store)
        self.assertFalse(store.has_extended_api())
        store.append("foo", b"po")
        store.append("foo", b"tato")
        self.assertEqual(store.get("foo"), b"potato")

        store.multi_set(["a", "b"], [b"c", b"d"])
        self.assertEqual(store.multi_get(["a", "b", "foo"]), [b"c", b"d", b"potato"])

class TestMultiThreadedWait(MultiThreadedTestCase):
    # TODO: Use less hacky means of instantiating stores.
    # Note, stores accumulate values per test.
    stores = [
        dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1),
        dist.HashStore(),
        dist.PrefixStore("pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)),
        create_tcp_store(),
        create_tcp_store(use_libuv=True),
        dist.PrefixStore("pre", create_tcp_store()),
        dist.PrefixStore("pre", create_tcp_store(use_libuv=True)),
    ]

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    # Iterates over self.stores; keep 7 in sync with len(self.stores).
    @parametrize("i", range(7))
    def test_wait(self, i):
        store = self.stores[i]
        store.set_timeout(timedelta(seconds=2))
        if dist.get_rank() == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        if dist.get_rank() == 1:
            store.set("key1", "value1")


instantiate_parametrized_tests(TestMultiThreadedWait)

@skip_if_win32()
class TimeoutTest(TestCase):
    def tearDown(self):
        import signal
        super().tearDown()
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        import signal
        rank_res = [None, None]

        def run(rank, my_store):
            nonlocal rank_res
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=2, is_master=True, wait_for_workers=False)
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=rank0_store.port, world_size=2, is_master=False, wait_for_workers=False)

        ths = []
        for i in range(2):
            t = threading.Thread(target=run, args=(i, [rank0_store, rank1_store][i],))
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        signal.signal(signal.SIGUSR1, handler)
        time.sleep(1)
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        self.assertTrue(rank_res[0], "rank0")
        self.assertTrue(rank_res[1], "rank1")

class InitPgWithUvStore(TestCase):
    def tearDown(self):
        super().tearDown()
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        port = common.find_free_port()
        dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=1")
        self._run_test()

    def test_with_env_var(self):
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "1"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        pg = dist.group.WORLD
        store = c10d._get_process_group_store(pg)
        self.assertTrue(isinstance(store, dist.PrefixStore))
        # c10d does multiple levels of wrapping
        while isinstance(store, dist.PrefixStore):
            store = store.underlying_store
        self.assertTrue(isinstance(store, dist.TCPStore))
        self.assertTrue(store.libuvBackend)
        dist.destroy_process_group()

if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()