pytorch/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py
Yuanyuan Chen a8c528c105 [1/N] Apply UP035 rule in tests (#163947)
Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947
Approved by: https://github.com/ezyang
2025-09-29 01:42:01 +00:00

1929 lines
62 KiB
Python

# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import os
import pickle
import socket
import threading
import time
from abc import ABC, abstractmethod
from base64 import b64encode
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import cast, Optional
from unittest import TestCase
from unittest.mock import call, MagicMock, Mock, patch, PropertyMock
import torch.distributed as dist
from torch.distributed import HashStore, Store
from torch.distributed.elastic.rendezvous import (
RendezvousClosedError,
RendezvousError,
RendezvousInfo,
RendezvousParameters,
RendezvousStateError,
RendezvousStoreInfo,
RendezvousTimeoutError,
)
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (
_Action,
_BackendRendezvousStateHolder,
_DistributedRendezvousOpExecutor,
_NodeDesc,
_NodeDescGenerator,
_RendezvousCloseOp,
_RendezvousContext,
_RendezvousExitOp,
_RendezvousJoinOp,
_RendezvousKeepAliveOp,
_RendezvousState,
_RendezvousStateHolder,
create_handler,
DynamicRendezvousHandler,
RendezvousBackend,
RendezvousSettings,
RendezvousTimeout,
Token,
)
# Port number that DummyStore.port reports and that the shared-store tests
# expect to see as the bootstrap store's master_port.
TEST_PORT = 54321
# Host name that the shared-store tests expect to see as the bootstrap store's
# master_addr.
TEST_ADDR = "host"
class CustomAssertMixin:
assertDictEqual: Callable
def assert_state_equal(
self, actual: _RendezvousState, expected: _RendezvousState
) -> None:
self.assertDictEqual(vars(actual), vars(expected))
def assert_state_empty(self, actual: _RendezvousState) -> None:
self.assertDictEqual(vars(actual), vars(_RendezvousState()))
class RendezvousTimeoutTest(TestCase):
    """Checks the values ``RendezvousTimeout`` exposes and the validation it performs."""

    def _assert_seconds(
        self,
        timeout: RendezvousTimeout,
        join: int,
        last_call: int,
        close: int,
        heartbeat: int,
    ) -> None:
        """Compare each timeout attribute against the given number of seconds."""
        self.assertEqual(timeout.join, timedelta(seconds=join))
        self.assertEqual(timeout.last_call, timedelta(seconds=last_call))
        self.assertEqual(timeout.close, timedelta(seconds=close))
        self.assertEqual(timeout.heartbeat, timedelta(seconds=heartbeat))

    def test_init_initializes_timeout(self) -> None:
        """Positional arguments map to join, last_call, close and heartbeat."""
        join = timedelta(seconds=50)
        last_call = timedelta(seconds=60)
        close = timedelta(seconds=70)
        heartbeat = timedelta(seconds=80)

        timeout = RendezvousTimeout(join, last_call, close, heartbeat)

        self._assert_seconds(timeout, 50, 60, 70, 80)

    def test_init_initializes_timeout_if_no_timeout_is_specified(self) -> None:
        """Without arguments the expected values are 600 s, 30 s, 30 s and 5 s."""
        self._assert_seconds(RendezvousTimeout(), 600, 30, 30, 5)

    def test_init_raises_error_if_timeout_is_not_positive(self) -> None:
        """A zero or negative join timeout is rejected with ``ValueError``."""
        for join_timeout in (timedelta(seconds=0), timedelta(seconds=-1)):
            with self.subTest(join_timeout=join_timeout):
                expected_message = (
                    rf"^The join timeout \({join_timeout}\) must be positive.$"
                )
                with self.assertRaisesRegex(ValueError, expected_message):
                    RendezvousTimeout(join_timeout)
class NodeDescTest(TestCase):
    """Checks the ``repr`` and hashability of ``_NodeDesc``."""

    def test_repr(self) -> None:
        """``repr`` joins the three constructor arguments with underscores."""
        node = _NodeDesc("dummy_fqdn", 3, 5)
        rendered = repr(node)
        self.assertEqual(rendered, "dummy_fqdn_3_5")

    def test_hash(self) -> None:
        """Two descriptors can be stored in a set and found again."""
        first = _NodeDesc("dummy_fqdn", 2, 4)
        second = _NodeDesc("dummy_fqdn", 3, 5)
        members = {first, second}
        for node in (first, second):
            self.assertIn(node, members)
class NodeDescGeneratorTest(TestCase):
    """Checks the descriptors produced by ``_NodeDescGenerator``."""

    def test_generate(self) -> None:
        """Successive descriptors render as ``<fqdn>_<pid>_<n>`` with n = 0, 1, 2, 3."""
        generator = _NodeDescGenerator()
        fqdn, pid = socket.getfqdn(), os.getpid()

        for local_id in range(4):
            with self.subTest(fqdn=fqdn, pid=pid, local_id=local_id):
                rendered = repr(generator.generate())
                self.assertEqual(rendered, f"{fqdn}_{pid}_{local_id}")
class RendezvousStateTest(TestCase):
    """Guards the serialized size of ``_RendezvousState``."""

    def test_encoded_size_is_within_expected_limit(self) -> None:
        """The base64-encoded pickle of a populated state stays within fixed budgets.

        Every index registers one participant and one waiting node, so a row
        with ``num_nodes`` entries stands for ``2 * num_nodes`` machines. The
        same state object is reused for all rows and keeps growing.
        """
        state = _RendezvousState()
        state.round = 1
        state.complete = True
        state.deadline = datetime.now(timezone.utc)
        state.closed = True

        kib = 2**10
        # (node pairs to register, maximum encoded size in bytes)
        size_budgets = (
            (5, 2 * kib),  # 10 machines <= 2KB
            (50, 16 * kib),  # 100 machines <= 16KB
            (500, 160 * kib),  # 1000 machines <= 160KB
            (5000, 1600 * kib),  # 10000 machines <= 1.6MB
        )

        for num_nodes, max_byte_size in size_budgets:
            with self.subTest(num_nodes=num_nodes, max_byte_size=max_byte_size):
                for i in range(num_nodes):
                    node_running = _NodeDesc(
                        f"dummy{i}.dummy1-dummy1-dummy1-dummy1.com", 12345, i
                    )
                    node_waiting = _NodeDesc(
                        f"dummy{i}.dummy2-dummy2-dummy2-dummy2.com", 67890, i
                    )

                    state.participants[node_running] = i
                    state.wait_list.add(node_waiting)

                    state.last_heartbeats[node_running] = datetime.now(timezone.utc)
                    state.last_heartbeats[node_waiting] = datetime.now(timezone.utc)

                encoded = b64encode(pickle.dumps(state))

                self.assertLessEqual(len(encoded), max_byte_size)
class FakeRendezvousBackend(RendezvousBackend):
    """In-memory ``RendezvousBackend`` used as a test double.

    The state is one bytes blob paired with an integer token. The token starts
    at ``0`` and is incremented by every write, including the ``*_internal``
    and ``corrupt_state`` helpers that bypass the token check.
    """

    _state: Optional[bytes]
    _token: int

    def __init__(self) -> None:
        self._state = None
        self._token = 0

    @property
    def name(self) -> str:
        """Return the fixed backend name."""
        return "fake_backend"

    def get_state(self) -> Optional[tuple[bytes, Token]]:
        """Return ``(state, token)``, or ``None`` while the token is still ``0``."""
        if self._token != 0:
            return self._state, self._token  # type: ignore[return-value]
        return None

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[tuple[bytes, Token, bool]]:
        """Store ``state`` only if ``token`` matches the current token.

        A ``token`` of ``None`` is treated as ``0``.

        Returns:
            ``(current state, current token, has_set)`` where ``has_set`` tells
            whether the write was applied.
        """
        requested_token = 0 if token is None else token

        has_set = requested_token == self._token
        if has_set:
            self._state = state
            self._token += 1

        return self._state, self._token, has_set  # type: ignore[return-value]

    def get_state_internal(self) -> _RendezvousState:
        """Unpickle and return the stored state."""
        raw = cast(bytes, self._state)
        return pickle.loads(raw)

    def set_state_internal(self, state: _RendezvousState) -> None:
        """Pickle ``state`` into the store and bump the token unconditionally."""
        self._state = pickle.dumps(state)
        self._token += 1

    def corrupt_state(self) -> None:
        """Overwrite the store with bytes that are not a valid pickle and bump the token."""
        self._state = b"corrupt_state"
        self._token += 1
class BackendRendezvousStateHolderTest(TestCase, CustomAssertMixin):
    """Exercises ``_BackendRendezvousStateHolder`` against ``FakeRendezvousBackend``."""

    def setUp(self) -> None:
        """Create the fake backend, call-counting mocks, settings and a frozen clock."""
        self._backend = FakeRendezvousBackend()

        mock_get_state = MagicMock(wraps=self._backend.get_state)
        mock_set_state = MagicMock(wraps=self._backend.set_state)

        # The same two mocks are exposed through a parent Mock (used by the
        # tests for call counting and resets) and installed on the backend.
        self._mock_backend = Mock()
        self._mock_backend.get_state = mock_get_state
        self._mock_backend.set_state = mock_set_state

        setattr(self._backend, "get_state", mock_get_state)  # noqa: B010
        setattr(self._backend, "set_state", mock_set_state)  # noqa: B010

        self._settings = RendezvousSettings(
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            timeout=RendezvousTimeout(),
            keep_alive_interval=timedelta(seconds=30),
            keep_alive_max_attempt=3,
        )

        self._cache_duration = 0

        self._now = datetime(2000, 1, 1, hour=0, minute=0)

        # Replace ``datetime`` inside the module under test so that
        # ``utcnow()`` reports the fixed instant above.
        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
        )

        mock_datetime = self._datetime_patch.start()
        mock_datetime.utcnow.return_value = self._now

    def tearDown(self) -> None:
        """Undo the ``datetime`` patch installed by ``setUp``."""
        self._datetime_patch.stop()

    def _create_state(self) -> _RendezvousState:
        """Return a closed, complete state with three participants and two waiting nodes.

        Heartbeats range from "now" (dummy1) to 90 seconds ago (dummy5).
        """
        state = _RendezvousState()
        state.round = 999
        state.complete = True
        state.deadline = self._now
        state.closed = True
        state.participants = {
            _NodeDesc("dummy1", 1, 1): 0,
            _NodeDesc("dummy2", 1, 1): 1,
            _NodeDesc("dummy3", 1, 1): 2,
        }
        state.wait_list = {
            _NodeDesc("dummy4", 1, 1),
            _NodeDesc("dummy5", 1, 1),
        }
        state.last_heartbeats = {
            _NodeDesc("dummy1", 1, 1): self._now,
            _NodeDesc("dummy2", 1, 1): self._now - timedelta(seconds=15),
            _NodeDesc("dummy3", 1, 1): self._now - timedelta(seconds=30),
            _NodeDesc("dummy4", 1, 1): self._now - timedelta(seconds=60),
            _NodeDesc("dummy5", 1, 1): self._now - timedelta(seconds=90),
        }

        return state

    def _create_state_holder(self) -> _BackendRendezvousStateHolder:
        """Build the holder under test from the current backend, settings and cache duration."""
        return _BackendRendezvousStateHolder(
            self._backend, self._settings, self._cache_duration
        )

    def test_init_initializes_state_holder(self) -> None:
        """A new holder has an empty state and has not touched the backend."""
        state_holder = self._create_state_holder()

        self.assert_state_empty(state_holder.state)

        # ``Mock.assert_not_called()`` on the parent only checks that the
        # parent itself was never called; it ignores calls to its children.
        # Check the two backend entry points explicitly instead.
        self.assertEqual(self._mock_backend.get_state.call_count, 0)
        self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_gets_empty_state_if_backend_state_does_not_exist(self) -> None:
        """Syncing against an empty backend reads once and leaves the state empty."""
        state_holder = self._create_state_holder()

        has_set = state_holder.sync()

        self.assertIsNone(has_set)

        self.assert_state_empty(state_holder.state)

        self.assertEqual(self._mock_backend.get_state.call_count, 1)
        self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_gets_backend_state_if_local_state_is_clean(self) -> None:
        """A clean holder adopts whatever the backend currently stores."""
        state_holder = self._create_state_holder()

        expected_state = self._create_state()

        for attempt in range(1, 4):
            with self.subTest(attempt=attempt):
                expected_state.round = attempt

                self._backend.set_state_internal(expected_state)

                has_set = state_holder.sync()

                self.assertIsNone(has_set)

                self.assert_state_equal(state_holder.state, expected_state)

                self.assertEqual(self._mock_backend.get_state.call_count, 1)
                self.assertEqual(self._mock_backend.set_state.call_count, 0)

                self._mock_backend.reset_mock()

    def test_sync_gets_backend_state_if_local_state_is_old_and_dirty(self) -> None:
        """A dirty holder with a stale token fails to write and takes the backend state."""
        state_holder = self._create_state_holder()

        expected_state = self._create_state()

        for attempt in range(1, 4):
            with self.subTest(attempt=attempt):
                self._backend.set_state_internal(expected_state)  # Increment token.

                state_holder.state.round = attempt
                state_holder.mark_dirty()

                has_set = state_holder.sync()

                self.assertFalse(has_set)

                self.assert_state_equal(state_holder.state, expected_state)

                self.assertEqual(self._mock_backend.get_state.call_count, 0)
                self.assertEqual(self._mock_backend.set_state.call_count, 1)

                self._mock_backend.reset_mock()

    def test_sync_sets_backend_state_if_local_state_is_new_and_dirty(self) -> None:
        """A dirty holder with a current token writes its state to the backend."""
        state_holder = self._create_state_holder()

        for attempt in range(1, 4):
            with self.subTest(attempt=attempt):
                state_holder.state.round = attempt
                state_holder.mark_dirty()

                has_set = state_holder.sync()

                self.assertTrue(has_set)

                expected_state = self._backend.get_state_internal()

                self.assert_state_equal(state_holder.state, expected_state)

                self.assertEqual(self._mock_backend.get_state.call_count, 0)
                self.assertEqual(self._mock_backend.set_state.call_count, 1)

                self._mock_backend.reset_mock()

    def test_sync_uses_cached_state_if_cache_duration_is_specified(self) -> None:
        """Repeated syncs within the cache duration read the backend only once."""
        state = self._create_state()

        self._backend.set_state_internal(state)

        with patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
        ) as mock_time:
            for cache_duration in [1, 5, 10]:
                with self.subTest(cache_duration=cache_duration):
                    self._cache_duration = cache_duration

                    state_holder = self._create_state_holder()

                    mock_time.monotonic.return_value = 5

                    state_holder.sync()

                    has_set = state_holder.sync()

                    self.assertIsNone(has_set)

                    self.assertEqual(self._mock_backend.get_state.call_count, 1)
                    self.assertEqual(self._mock_backend.set_state.call_count, 0)

                    # Exactly ``cache_duration`` later the cached copy is
                    # still expected to be served.
                    mock_time.monotonic.return_value = 5 + self._cache_duration

                    state_holder.sync()

                    has_set = state_holder.sync()

                    self.assertIsNone(has_set)

                    self.assertEqual(self._mock_backend.get_state.call_count, 1)
                    self.assertEqual(self._mock_backend.set_state.call_count, 0)

                    self._mock_backend.get_state.reset_mock()

    def test_sync_gets_backend_state_if_cached_state_has_expired(self) -> None:
        """Once the clock passes the cache duration the backend is read again."""
        state = self._create_state()

        self._backend.set_state_internal(state)

        with patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
        ) as mock_time:
            self._cache_duration = 1

            state_holder = self._create_state_holder()

            mock_time.monotonic.return_value = 5

            state_holder.sync()

            has_set = state_holder.sync()

            self.assertIsNone(has_set)

            self.assertEqual(self._mock_backend.get_state.call_count, 1)
            self.assertEqual(self._mock_backend.set_state.call_count, 0)

            mock_time.monotonic.return_value = 5 + self._cache_duration + 0.01

            state_holder.sync()

            has_set = state_holder.sync()

            self.assertIsNone(has_set)

            self.assertEqual(self._mock_backend.get_state.call_count, 2)
            self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_sanitizes_state(self) -> None:
        """Nodes whose heartbeat is 91 s old or older are dropped during sync.

        Nodes from ``_create_state`` (heartbeats up to 90 s old) are expected to
        survive. NOTE(review): the cut-off presumably equals
        keep_alive_interval * keep_alive_max_attempt (30 s * 3); confirm against
        the implementation.
        """
        state = self._create_state()

        expected_state = copy.deepcopy(state)

        dead_node1 = _NodeDesc("dead1", 1, 1)
        dead_node2 = _NodeDesc("dead2", 1, 1)
        dead_node3 = _NodeDesc("dead3", 1, 1)
        dead_node4 = _NodeDesc("dead4", 1, 1)
        dead_node5 = _NodeDesc("dead5", 1, 1)

        state.last_heartbeats[dead_node1] = self._now - timedelta(seconds=91)
        state.last_heartbeats[dead_node2] = self._now - timedelta(seconds=100)
        state.last_heartbeats[dead_node3] = self._now - timedelta(seconds=110)
        state.last_heartbeats[dead_node4] = self._now - timedelta(seconds=120)
        state.last_heartbeats[dead_node5] = self._now - timedelta(seconds=130)

        state.participants[dead_node1] = 0
        state.participants[dead_node2] = 0
        state.participants[dead_node3] = 0

        state.wait_list.add(dead_node4)
        state.wait_list.add(dead_node5)

        self._backend.set_state_internal(state)

        state_holder = self._create_state_holder()

        state_holder.sync()

        self.assert_state_equal(state_holder.state, expected_state)

    def test_sync_sanitizes_state_if_no_participants_is_left(self) -> None:
        """When every node is stale the round advances and all node lists are emptied."""
        state = self._create_state()

        expected_state = copy.deepcopy(state)

        for node in state.last_heartbeats:
            state.last_heartbeats[node] = self._now - timedelta(seconds=100)

        expected_state.complete = False
        expected_state.round = 1000
        expected_state.participants = {}
        expected_state.wait_list = set()
        expected_state.last_heartbeats = {}

        self._backend.set_state_internal(state)

        state_holder = self._create_state_holder()

        state_holder.sync()

        self.assert_state_equal(state_holder.state, expected_state)

    def test_sync_raises_error_if_backend_state_is_corrupt(self) -> None:
        """Undecodable backend bytes surface as ``RendezvousStateError``."""
        self._backend.corrupt_state()

        state_holder = self._create_state_holder()

        with self.assertRaisesRegex(
            RendezvousStateError,
            r"^The rendezvous state is corrupt. See inner exception for details.$",
        ):
            state_holder.sync()
class FakeRendezvousStateHolder(_RendezvousStateHolder):
_state: _RendezvousState
_dirty: Optional[bool]
def __init__(self) -> None:
self._state = _RendezvousState()
self._dirty = None
@property
def state(self) -> _RendezvousState:
return self._state
@state.setter
def state(self, value) -> None:
self._state = value
def sync(self) -> Optional[bool]:
self._dirty, dirty = None, self._dirty
return dirty
def mark_dirty(self) -> None:
self._dirty = True
class DistributedRendezvousOpExecutorTest(TestCase, CustomAssertMixin):
    """Exercises ``_DistributedRendezvousOpExecutor`` with a fake state holder."""

    def setUp(self) -> None:
        """Create the node, the fake state holder with recording mocks, and a frozen clock."""
        self._node = _NodeDesc("this_node", 1, 1)

        self._state_holder = FakeRendezvousStateHolder()

        sync_spy = MagicMock(wraps=self._state_holder.sync)
        mark_spy = MagicMock(wraps=self._state_holder.mark_dirty)

        # The parent mock records the order of ``sync``/``mark`` calls; the
        # same spies are installed on the fake holder.
        self._mock_state_holder = Mock()
        self._mock_state_holder.sync = sync_spy
        self._mock_state_holder.mark = mark_spy

        setattr(self._state_holder, "sync", sync_spy)  # noqa: B010
        setattr(self._state_holder, "mark_dirty", mark_spy)  # noqa: B010

        self._state = self._state_holder.state

        self._min_nodes = 1
        self._max_nodes = 1

        self._timeout = RendezvousTimeout()

        self._now = datetime(2000, 1, 1, hour=0, minute=0)

        # ``utcnow()`` in the module under test reports the fixed instant above.
        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
        )

        patched_datetime = self._datetime_patch.start()
        patched_datetime.utcnow.return_value = self._now

    def tearDown(self) -> None:
        """Undo the ``datetime`` patch installed by ``setUp``."""
        self._datetime_patch.stop()

    def _create_settings(self) -> RendezvousSettings:
        """Build settings from the current min/max node counts and timeout."""
        return RendezvousSettings(
            run_id="dummy_run_id",
            min_nodes=self._min_nodes,
            max_nodes=self._max_nodes,
            timeout=self._timeout,
            keep_alive_interval=timedelta(seconds=30),
            keep_alive_max_attempt=3,
        )

    def _create_op_executor(
        self, settings: Optional[RendezvousSettings] = None
    ) -> _DistributedRendezvousOpExecutor:
        """Push ``self._state`` into the holder and build the executor under test."""
        self._state_holder.state = self._state

        effective_settings = self._create_settings() if settings is None else settings

        return _DistributedRendezvousOpExecutor(
            self._node, self._state_holder, effective_settings
        )

    def _run_action(self, action: _Action) -> None:
        """Run an operation that yields ``action`` once and then ``FINISH``."""
        executor = self._create_op_executor()

        state_handler = MagicMock(side_effect=[action, _Action.FINISH])

        executor.run(state_handler, deadline=1)

    def _assert_action(self, action: _Action, expected_state: _RendezvousState) -> None:
        """Run ``action`` and check the resulting state and the sync/mark/sync call order."""
        self._run_action(action)

        self.assert_state_equal(self._state, expected_state)

        expected_calls = [call.sync(), call.mark(), call.sync()]
        self.assertListEqual(self._mock_state_holder.mock_calls, expected_calls)

    def test_run_passes_expected_context_and_deadline_to_state_handler(self) -> None:
        """The handler receives a context with node, state and settings, plus the deadline."""
        settings = self._create_settings()

        executor = self._create_op_executor(settings)

        state_handler = MagicMock(return_value=_Action.FINISH)

        executor.run(state_handler, deadline=3)

        ctx, deadline = state_handler.call_args[0]  # args

        self.assertIs(ctx.node, self._node)
        self.assertIs(ctx.state, self._state)
        self.assertIs(ctx.settings, settings)

        self.assertEqual(deadline, 3)

    def test_run_keeps_alive(self) -> None:
        """``KEEP_ALIVE`` records the current time as this node's heartbeat."""
        expected_state = _RendezvousState()
        expected_state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.KEEP_ALIVE, expected_state)

    def test_run_adds_to_participants(self) -> None:
        """``ADD_TO_PARTICIPANTS`` registers the node with rank 0 and a heartbeat."""
        self._min_nodes = 2
        self._max_nodes = 2

        expected_state = _RendezvousState()
        expected_state.participants[self._node] = 0
        expected_state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

    def test_run_adds_to_participants_if_node_was_in_waitlist(self) -> None:
        """A waiting node is moved out of the wait list when it becomes a participant."""
        self._state.wait_list.add(self._node)

        self._min_nodes = 2
        self._max_nodes = 2

        expected_state = _RendezvousState()
        expected_state.participants[self._node] = 0
        expected_state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

    def _add_participants(
        self, num_participants: int, state: _RendezvousState, ranked: bool = False
    ) -> None:
        """Add ``num_participants`` dummy nodes with a current heartbeat to ``state``.

        With ``ranked`` the nodes are added as dummy0, dummy1, ... with ranks
        0, 1, ...; otherwise they are added in reverse name order, all with
        rank 0.
        """
        for index in range(num_participants):
            if ranked:
                name_suffix, rank = index, index
            else:
                # Add in reverse.
                name_suffix, rank = num_participants - index - 1, 0

            node = _NodeDesc(f"dummy{name_suffix}", 1, 1)

            state.participants[node] = rank
            state.last_heartbeats[node] = self._now

    def test_run_adds_to_participants_and_starts_last_call_if_min_nodes_is_reached(
        self,
    ) -> None:
        """Reaching ``min_nodes`` sets the deadline to now plus the last-call timeout."""
        for num_participants in range(3):
            self._state = _RendezvousState()

            self._add_participants(num_participants, self._state)

            self._state.wait_list.add(self._node)

            expected_state = _RendezvousState()

            self._add_participants(num_participants, expected_state)

            expected_state.participants[self._node] = 0
            expected_state.last_heartbeats[self._node] = self._now
            expected_state.deadline = self._now + self._timeout.last_call

            with self.subTest(num_participants=num_participants):
                self._min_nodes = num_participants + 1
                self._max_nodes = num_participants + 2

                self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

                self._mock_state_holder.reset_mock()

    def test_run_adds_to_participants_and_completes_rendezvous_if_max_nodes_is_reached(
        self,
    ) -> None:
        """Reaching ``max_nodes`` completes the rendezvous, assigns ranks and clears the deadline."""
        for min_max_nodes_equal in [False, True]:
            for num_participants in range(3):
                self._state = _RendezvousState()

                self._add_participants(num_participants, self._state)

                self._state.wait_list.add(self._node)

                self._state.deadline = self._now + self._timeout.last_call

                expected_state = _RendezvousState()

                self._add_participants(num_participants, expected_state, ranked=True)

                # This node is expected to receive the next free rank.
                expected_state.participants[self._node] = num_participants
                expected_state.last_heartbeats[self._node] = self._now

                expected_state.complete = True
                expected_state.deadline = None

                with self.subTest(num_participants=num_participants):
                    if min_max_nodes_equal:
                        self._min_nodes = num_participants + 1
                    else:
                        self._min_nodes = 0
                    self._max_nodes = num_participants + 1

                    self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

                    self._mock_state_holder.reset_mock()

    def test_run_adds_to_waitlist(self) -> None:
        """``ADD_TO_WAIT_LIST`` puts the node on the wait list with a heartbeat."""
        expected_state = _RendezvousState()
        expected_state.wait_list.add(self._node)
        expected_state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.ADD_TO_WAIT_LIST, expected_state)

    def test_run_removes_from_participants(self) -> None:
        """Removing this node leaves the other participants, round and deadline untouched."""
        for complete, last_call_deadline in [(False, self._now), (True, None)]:
            self._state = _RendezvousState()

            self._add_participants(2, self._state)

            self._state.participants[self._node] = 0
            self._state.last_heartbeats[self._node] = self._now

            self._state.complete = complete
            self._state.deadline = last_call_deadline
            self._state.round = 1

            expected_state = _RendezvousState()

            self._add_participants(2, expected_state)

            expected_state.complete = complete
            expected_state.deadline = last_call_deadline
            expected_state.round = 1

            with self.subTest(complete=complete):
                self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state)

                self._mock_state_holder.reset_mock()

    def test_run_removes_from_participants_and_moves_to_next_round_if_node_is_last_participant(
        self,
    ) -> None:
        """Removing the only participant resets completion and increments the round."""
        self._state.participants[self._node] = 0
        self._state.last_heartbeats[self._node] = self._now

        self._state.complete = True
        self._state.round = 1

        expected_state = _RendezvousState()
        expected_state.complete = False
        expected_state.round = 2

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state)

    def test_run_removes_from_participants_and_clears_last_call_if_rendezvous_has_less_than_min_nodes(
        self,
    ) -> None:
        """Dropping below ``min_nodes`` clears a pending last-call deadline."""
        self._add_participants(2, self._state)

        self._state.participants[self._node] = 0
        self._state.last_heartbeats[self._node] = self._now

        self._state.deadline = self._now

        expected_state = _RendezvousState()

        self._add_participants(2, expected_state)

        self._min_nodes = 3
        self._max_nodes = 4

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state)

    def test_run_removes_from_waitlist(self) -> None:
        """``REMOVE_FROM_WAIT_LIST`` drops the node and its heartbeat."""
        self._state.wait_list.add(self._node)
        self._state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.REMOVE_FROM_WAIT_LIST, _RendezvousState())

    def test_run_marks_rendezvous_closed(self) -> None:
        """``MARK_RENDEZVOUS_CLOSED`` sets the ``closed`` flag."""
        expected_state = _RendezvousState()
        expected_state.closed = True

        self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED, expected_state)

    def test_run_raises_error_if_rendezvous_is_closed(self) -> None:
        """``ERROR_CLOSED`` raises ``RendezvousClosedError`` after a single sync."""
        with self.assertRaises(RendezvousClosedError):
            self._run_action(_Action.ERROR_CLOSED)

        self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()])

    def test_run_raises_error_if_operation_timed_out(self) -> None:
        """``ERROR_TIMEOUT`` raises ``RendezvousTimeoutError`` after a single sync."""
        with self.assertRaises(RendezvousTimeoutError):
            self._run_action(_Action.ERROR_TIMEOUT)

        self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()])

    def test_run_delays_execution_if_sync_requested(self) -> None:
        """``SYNC`` waits one second via ``_delay`` and syncs again without marking dirty."""
        with patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay"
        ) as mock_delay:
            self._run_action(_Action.SYNC)

            mock_delay.assert_called_once_with(seconds=1)

        expected_calls = [call.sync(), call.sync()]
        self.assertListEqual(self._mock_state_holder.mock_calls, expected_calls)
class AbstractTestRendezvousOp(ABC):
assertEqual: Callable
def setUp(self) -> None:
self._node = _NodeDesc("this_node", 1, 1)
self._min_nodes = 1
self._max_nodes = 2
self._keep_alive_interval = timedelta(seconds=30)
self._state = _RendezvousState()
self._state.participants[_NodeDesc("dummy1", 1, 1)] = 1
self._now = datetime(2000, 1, 1, hour=0, minute=0)
self._deadline = 10
self._datetime_patch = patch(
"torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
)
mock_datetime = self._datetime_patch.start()
mock_datetime.utcnow.return_value = self._now
self._time_patch = patch(
"torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
)
mock_time = self._time_patch.start()
mock_time.monotonic.return_value = self._deadline
def tearDown(self) -> None:
self._time_patch.stop()
self._datetime_patch.stop()
def _get_next_action(self) -> _Action:
op = self._create_op()
settings = RendezvousSettings(
run_id="dummy_run_id",
min_nodes=self._min_nodes,
max_nodes=self._max_nodes,
timeout=RendezvousTimeout(),
keep_alive_interval=self._keep_alive_interval,
keep_alive_max_attempt=3,
)
ctx = _RendezvousContext(self._node, self._state, settings)
return op(ctx, self._deadline)
@abstractmethod
def _create_op(self) -> Callable:
pass
def _assert_action(self, expected_action) -> None:
action = self._get_next_action()
self.assertEqual(action, expected_action)
class TestRendezvousExitOp(AbstractTestRendezvousOp, TestCase):
    """Checks the actions produced by ``_RendezvousExitOp``."""

    def _create_op(self) -> Callable:
        """Return a fresh exit operation."""
        return _RendezvousExitOp()

    def test_removes_from_participants_if_node_is_participant(self) -> None:
        """A participating node is asked to leave the participants."""
        self._state.participants[self._node] = 1

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS)

    def test_raises_timeout_if_deadline_exceeded(self) -> None:
        """With the deadline set to 0 a participant gets ``ERROR_TIMEOUT``."""
        self._state.participants[self._node] = 1
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_finishes_if_node_is_not_participant(self) -> None:
        """A node that is not a participant has nothing to do."""
        self._assert_action(_Action.FINISH)
class TestRendezvousJoinOp(AbstractTestRendezvousOp, TestCase):
    """Checks the actions produced by ``_RendezvousJoinOp``."""

    def _create_op(self) -> Callable:
        """Return a fresh join operation."""
        return _RendezvousJoinOp()

    def test_raises_closed_if_rendezvous_is_closed(self) -> None:
        """A closed rendezvous yields ``ERROR_CLOSED``."""
        self._state.closed = True

        self._assert_action(_Action.ERROR_CLOSED)

    def test_finishes_if_rendezvous_is_complete_and_node_is_participant(self) -> None:
        """A participant of a completed rendezvous is done."""
        self._state.complete = True
        self._state.participants[self._node] = 0

        self._assert_action(_Action.FINISH)

    def _assert_waits_rendezvous_completion(self) -> None:
        """Assert the waiting behaviour for two heartbeat ages.

        A heartbeat exactly one keep-alive interval old yields ``KEEP_ALIVE``;
        one that is a second newer yields ``SYNC``.
        """
        keep_alive_time = self._now - self._keep_alive_interval

        cases = (
            (timedelta(seconds=0), _Action.KEEP_ALIVE),
            (timedelta(seconds=1), _Action.SYNC),
        )
        for delta, expected_action in cases:
            self._state.last_heartbeats[self._node] = keep_alive_time + delta

            self._assert_action(expected_action)

    def test_treat_as_redundancy_for_next_rendezvous_if_rendezvous_is_complete(
        self,
    ) -> None:
        """With ``max_nodes`` 1 and a completed rendezvous the node becomes redundant."""
        self._state.complete = True
        self._max_nodes = 1

        self._assert_action(_Action.ADD_TO_REDUNDANCY_LIST)

    def test_waits_next_round_if_rendezvous_is_complete_and_node_is_redundant(
        self,
    ) -> None:
        """A node already on the redundancy list waits for the next round."""
        self._state.redundancy_list.add(self._node)
        self._state.complete = True
        self._max_nodes = 1

        self._assert_waits_rendezvous_completion()

    def test_remove_from_rednundancy_list(self) -> None:
        """With ``max_nodes`` 2 a redundant node is taken off the redundancy list."""
        self._state.redundancy_list.add(self._node)
        self._state.complete = True
        self._max_nodes = 2

        self._assert_action(_Action.REMOVE_FROM_REDUNDANCY_LIST)

    def test_waits_next_round_if_rendezvous_is_complete_and_node_is_in_wait_list(
        self,
    ) -> None:
        """A node on the wait list of a completed rendezvous keeps waiting."""
        self._state.complete = True
        self._state.wait_list.add(self._node)

        self._assert_waits_rendezvous_completion()

    def test_adds_to_wait_list_if_rendezvous_is_complete_and_num_nodes_is_less_than_max_nodes(
        self,
    ) -> None:
        """A new node joins the wait list of a completed, non-full rendezvous."""
        self._state.complete = True

        self._assert_action(_Action.ADD_TO_WAIT_LIST)

    def test_waits_rendezvous_to_complete_if_node_is_participant(self) -> None:
        """A participant waits while the last-call deadline equals the current time."""
        self._max_nodes = 3

        self._state.deadline = self._now
        self._state.participants[self._node] = 0

        self._assert_waits_rendezvous_completion()

    def test_marks_rendezvous_complete_if_node_is_participant_and_last_call_deadline_exceeded(
        self,
    ) -> None:
        """A last-call deadline one second in the past completes the rendezvous."""
        self._max_nodes = 3

        self._state.deadline = self._now - timedelta(seconds=1)
        self._state.participants[self._node] = 0

        self._assert_action(_Action.MARK_RENDEZVOUS_COMPLETE)

    def test_adds_to_participants(self) -> None:
        """With the default fixture the node is added to the participants."""
        self._assert_action(_Action.ADD_TO_PARTICIPANTS)

    def test_raises_timeout_if_deadline_exceeded(self) -> None:
        """With the deadline set to 0 the operation times out."""
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_participant(
        self,
    ) -> None:
        """A participant with the deadline set to 0 gets ``ERROR_TIMEOUT``."""
        self._state.participants[self._node] = 0
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_in_wait_list(
        self,
    ) -> None:
        """A waiting node with the deadline set to 0 gets ``ERROR_TIMEOUT``."""
        self._state.wait_list.add(self._node)
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_removes_from_participants_if_timed_out_but_rollback_deadline_is_not_reached(
        self,
    ) -> None:
        """A participant with the deadline set to 5 is removed from the participants."""
        self._state.participants[self._node] = 0
        self._deadline = 5

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS)

    def test_removes_from_wait_list_if_timed_out_but_rollback_deadline_is_not_reached(
        self,
    ) -> None:
        """A waiting node with the deadline set to 5 is removed from the wait list."""
        self._state.wait_list.add(self._node)
        self._deadline = 5

        self._assert_action(_Action.REMOVE_FROM_WAIT_LIST)

    def test_no_timeout_for_redundant_node(self) -> None:
        """A redundant node gets ``SYNC`` even with the deadline set to 0."""
        self._max_nodes = 1
        self._deadline = 0

        self._state.complete = True
        self._state.redundancy_list.add(self._node)

        self._assert_action(_Action.SYNC)

    def test_keep_alive_for_redundant_node(self) -> None:
        """A redundant node whose heartbeat is one interval old gets ``KEEP_ALIVE``."""
        self._max_nodes = 1
        self._deadline = 0

        self._state.complete = True
        self._state.redundancy_list.add(self._node)
        self._state.last_heartbeats[self._node] = (
            self._now - self._keep_alive_interval
        )

        self._assert_action(_Action.KEEP_ALIVE)
class TestRendezvousCloseOp(AbstractTestRendezvousOp, TestCase):
    """Checks the actions produced by ``_RendezvousCloseOp``."""

    def _create_op(self) -> Callable:
        """Return a fresh close operation."""
        return _RendezvousCloseOp()

    def test_finishes_if_rendezvous_is_closed(self) -> None:
        """An already closed rendezvous needs no further action."""
        self._state.closed = True

        self._assert_action(_Action.FINISH)

    def test_raises_timeout_if_deadline_exceeded(self) -> None:
        """With the deadline set to 0 the operation times out."""
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_marks_rendezvous_closed(self) -> None:
        """With the default fixture the rendezvous is marked closed."""
        self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED)
class TestRendezvousKeepAliveOp(AbstractTestRendezvousOp, TestCase):
    """Checks the actions produced by ``_RendezvousKeepAliveOp``."""

    def _create_op(self) -> Callable:
        """Return a fresh keep-alive operation."""
        return _RendezvousKeepAliveOp()

    def test_updates_keep_alive_if_needed(self) -> None:
        """A heartbeat at least one keep-alive interval old is refreshed."""
        keep_alive_time = self._now - self._keep_alive_interval

        for delta in (timedelta(seconds=0), timedelta(seconds=-1)):
            with self.subTest(delta=delta):
                self._state.last_heartbeats[self._node] = keep_alive_time + delta

                self._assert_action(_Action.KEEP_ALIVE)

    def test_raises_timeout_if_deadlined_exceeded(self) -> None:
        """With the deadline set to 0 a due heartbeat update times out."""
        self._deadline = 0

        stale_heartbeat = self._now - self._keep_alive_interval
        self._state.last_heartbeats[self._node] = stale_heartbeat

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_finishes_if_no_keep_alive_update_is_needed(self) -> None:
        """A heartbeat one second newer than the interval needs no refresh."""
        fresh_heartbeat = (
            self._now - self._keep_alive_interval + timedelta(seconds=1)
        )
        self._state.last_heartbeats[self._node] = fresh_heartbeat

        self._assert_action(_Action.FINISH)
class DummyStore(Store):
    """``Store`` subclass whose only customization is a fixed ``port``."""

    @property
    def port(self) -> int:
        """Return the module-level ``TEST_PORT``."""
        fixed_port = TEST_PORT
        return fixed_port
class DynamicRendezvousHandlerTest(TestCase):
def setUp(self) -> None:
    """Prepare node, limits, stubbed store, fake state holder and TCP-store patch."""
    self._node = _NodeDesc("this_node", 1, 1)

    self._min_nodes = 1
    self._max_nodes = 1

    self._join_timeout: Optional[timedelta] = None
    self._close_timeout: Optional[timedelta] = None
    self._heartbeat_timeout: Optional[timedelta] = None

    self._keep_alive_interval = timedelta(seconds=30)

    # Store whose ``get``/``set`` are replaced by mocks; ``get`` always
    # answers b"123".
    self._store = DummyStore()

    self._mock_store_get = MagicMock(return_value=b"123")
    self._mock_store_set = MagicMock()

    setattr(self._store, "get", self._mock_store_get)  # noqa: B010
    setattr(self._store, "set", self._mock_store_set)  # noqa: B010

    # State holder whose ``sync`` is wrapped so tests can inject side effects.
    self._state_holder = FakeRendezvousStateHolder()

    self._mock_sync = MagicMock(wraps=self._state_holder.sync)

    setattr(self._state_holder, "sync", self._mock_sync)  # noqa: B010

    self._state = self._state_holder.state

    # ``_create_tcp_store_server`` returns this dummy store for the duration
    # of each test.
    self._tcp_store_mock = DummyStore()

    tcp_store_patcher = patch.object(
        DynamicRendezvousHandler,
        "_create_tcp_store_server",
        return_value=self._tcp_store_mock,
    )
    tcp_store_patcher.start()
    self.addCleanup(tcp_store_patcher.stop)
def _create_handler(self) -> DynamicRendezvousHandler:
    """Build a handler from the current fixture values.

    ``self._state`` is pushed into the fake state holder first so that the
    handler operates on the state prepared by the test.
    """
    timeout = RendezvousTimeout(
        join=self._join_timeout,
        close=self._close_timeout,
        heartbeat=self._heartbeat_timeout,
    )

    settings = RendezvousSettings(
        run_id="dummy_run_id",
        min_nodes=self._min_nodes,
        max_nodes=self._max_nodes,
        timeout=timeout,
        keep_alive_interval=self._keep_alive_interval,
        keep_alive_max_attempt=3,
    )

    self._state_holder.state = self._state

    return DynamicRendezvousHandler(
        self._node, settings, "dummy_backend", self._store, self._state_holder
    )
def test_share_store_creates_tcp_store(self):
    """The handler keeps the store returned by the patched ``_create_tcp_store_server``.

    While ``RendezvousStoreInfo.build`` is patched, the bootstrap store info
    must report ``TEST_ADDR``/``TEST_PORT``. A second rendezvous must leave
    the same shared store in place.
    """
    handler = self._create_handler()

    stub_store_info = RendezvousStoreInfo(TEST_ADDR, TEST_PORT)
    with patch.object(RendezvousStoreInfo, "build", return_value=stub_store_info):
        rdzv_info = handler.next_rendezvous()
        bootstrap_info = rdzv_info.bootstrap_store_info
        self.assertEqual(bootstrap_info.master_addr, TEST_ADDR)
        self.assertEqual(bootstrap_info.master_port, TEST_PORT)
    self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock)

    rdzv_info = handler.next_rendezvous()
    self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock)
def test_share_store_when_tcp_store(self):
    """With a ``TCPStore``-shaped store the bootstrap info comes from the prefix store.

    ``dist.PrefixStore`` is replaced by a stub answering ``TEST_ADDR`` for the
    "MASTER_ADDR" key and ``TEST_PORT`` for any other key. The handler's
    shared TCP store server must not be the handler's own store, on the first
    and on a repeated rendezvous.
    """
    handler = self._create_handler()

    class CustomPrefixStore(Mock):
        def get(self, key):
            if key == "MASTER_ADDR":
                return TEST_ADDR.encode("utf-8")
            return bytes(str(TEST_PORT), "utf-8")

        def set(self, key, value):
            pass

    with patch.object(dist, "PrefixStore", new=CustomPrefixStore):
        handler._store = Mock(spec=dist.TCPStore)
        store_type = type(handler._store)
        store_type.host = PropertyMock(return_value=TEST_ADDR)
        store_type.port = PropertyMock(return_value=TEST_PORT - 1)

        # Same expectations for the first and the repeated rendezvous.
        for _attempt in range(2):
            rdzv_info = handler.next_rendezvous()
            bootstrap_info = rdzv_info.bootstrap_store_info
            self.assertEqual(bootstrap_info.master_addr, TEST_ADDR)
            self.assertEqual(bootstrap_info.master_port, TEST_PORT)
            self.assertNotEqual(handler._shared_tcp_store_server, handler._store)
@patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay")
def test_next_rendezvous_skews_the_first_join_attempt(self, mock_delay) -> None:
    """Expect exactly one ``_delay`` call when joining round 0 and none for round 1."""
    # ``round_no`` avoids shadowing the ``round`` builtin; the expected counts
    # are plain integers because they are compared against ``call_count``.
    for round_no, expected_call_count in [(0, 1), (1, 0)]:
        # Reset up front so a failed sub-test cannot leak recorded calls into
        # the next round.
        mock_delay.reset_mock()

        with self.subTest(round=round_no):
            self._state.round = round_no

            handler = self._create_handler()

            handler.next_rendezvous()

            self.assertEqual(mock_delay.call_count, expected_call_count)
def test_next_rendezvous_returns_expected_value(self) -> None:
    """Joining next to two existing participants yields rank 2 of 3 and a prefixed store."""
    for addr in ("dummy1", "dummy2"):
        self._state.participants[_NodeDesc(addr, 1, 1)] = 0

    self._max_nodes = 3

    handler = self._create_handler()
    rdzv_info = handler.next_rendezvous()

    self.assertEqual(rdzv_info.rank, 2)
    self.assertEqual(rdzv_info.world_size, 3)

    # Reading through the returned store must reach the underlying store with
    # the run/round prefix prepended to the key.
    rdzv_info.store.get("dummy_key")
    self._mock_store_get.assert_called_with(
        "torch.rendezvous.dummy_run_id.0/dummy_key"
    )
def test_next_rendezvous_respects_the_requested_timeout(self) -> None:
    """A sync that takes 0.3 s must trip a 0.2 s join timeout."""

    def slow_sync():
        time.sleep(0.3)

    self._mock_sync.side_effect = slow_sync
    self._join_timeout = timedelta(seconds=0.2)

    handler = self._create_handler()

    with self.assertRaises(RendezvousTimeoutError):
        handler.next_rendezvous()
def test_next_rendezvous_moves_to_next_round_if_called_repeatedly(self) -> None:
    """Each successive ``next_rendezvous`` call advances the round counter by one."""
    handler = self._create_handler()

    for expected_round in range(4):
        handler.next_rendezvous()

        self.assertEqual(self._state.round, expected_round)
def test_is_closed_returns_expected_value(self) -> None:
    """``is_closed`` reflects the state's ``closed`` flag and syncs exactly once."""
    for expected in (False, True):
        with self.subTest(closed=expected):
            self._state.closed = expected

            handler = self._create_handler()
            self.assertEqual(handler.is_closed(), expected)

            self._mock_sync.assert_called_once()

            # Clear the recorded call so the next sub-test counts from zero.
            self._mock_sync.reset_mock()
@patch("torch.distributed.elastic.events.record_rdzv_event")
def test_is_closed_records_and_raises_exceptions(self, record_mock) -> None:
    """A ``RendezvousError`` raised by sync propagates out of ``is_closed``."""
    self._mock_sync.side_effect = RendezvousError("test error")

    handler = self._create_handler()

    with self.assertRaises(RendezvousError):
        handler.is_closed()
        # NOTE(review): this line follows the raising call inside the
        # assertRaises block, so it is not reached as written - confirm the
        # intended nesting before relying on it.
        record_mock.assert_called_once()
def test_set_closed_closes_rendezvous(self) -> None:
    """``set_closed`` flips the shared state's ``closed`` flag to true."""
    handler = self._create_handler()

    handler.set_closed()

    is_closed = self._state.closed
    self.assertTrue(is_closed)
def test_set_closed_respects_the_requested_timeout(self) -> None:
    """A sync that takes 0.3 s must trip a 0.2 s close timeout."""

    def slow_sync():
        time.sleep(0.3)

    self._mock_sync.side_effect = slow_sync
    self._close_timeout = timedelta(seconds=0.2)

    handler = self._create_handler()

    with self.assertRaises(RendezvousTimeoutError):
        handler.set_closed()
def test_set_closed_can_be_called_multiple_times(self) -> None:
    """Calling ``set_closed`` twice in a row leaves the rendezvous closed."""
    handler = self._create_handler()

    for _ in range(2):
        handler.set_closed()

    self.assertTrue(self._state.closed)
@patch("torch.distributed.elastic.events.record_rdzv_event")
def test_set_closed_records_and_raises_exceptions(self, record_mock) -> None:
    """A ``RendezvousError`` raised by ``_close`` propagates out of ``set_closed``."""
    failing_close = patch.object(DynamicRendezvousHandler, "_close")
    with failing_close as close_mock:
        close_mock.side_effect = RendezvousError("test error")

        handler = self._create_handler()

        with self.assertRaises(RendezvousError):
            handler.set_closed()
            # NOTE(review): this line follows the raising call inside the
            # assertRaises block, so it is not reached as written - confirm
            # the intended nesting before relying on it.
            record_mock.assert_called_once()
def test_num_nodes_waiting_returns_expected_value(self) -> None:
    """``num_nodes_waiting`` counts the wait-list entries after a single sync."""
    for addr in ("dummy1", "dummy2"):
        self._state.wait_list.add(_NodeDesc(addr, 1, 1))

    handler = self._create_handler()

    waiting = handler.num_nodes_waiting()
    self.assertEqual(waiting, 2)

    self._mock_sync.assert_called_once()
@patch("torch.distributed.elastic.events.record_rdzv_event")
def test_num_nodes_waiting_records_and_raises_exceptions(self, record_mock) -> None:
    """A ``RendezvousError`` raised by sync propagates out of ``num_nodes_waiting``."""
    self._mock_sync.side_effect = RendezvousError("test error")

    handler = self._create_handler()

    with self.assertRaises(RendezvousError):
        handler.num_nodes_waiting()
        # NOTE(review): this line follows the raising call inside the
        # assertRaises block, so it is not reached as written - confirm the
        # intended nesting before relying on it.
        record_mock.assert_called_once()
def test_shutdown_closes_rendezvous_and_returns_true(self) -> None:
    """A successful ``shutdown`` returns true and marks the state closed."""
    handler = self._create_handler()

    succeeded = handler.shutdown()

    self.assertTrue(succeeded)
    self.assertTrue(self._state.closed)
def test_shutdown_returns_false_if_rendezvous_cannot_be_closed(self) -> None:
    """``shutdown`` reports false when the first sync raises ``RendezvousError``."""
    # A one-element list makes the mock raise on its first call only.
    self._mock_sync.side_effect = [RendezvousError]

    handler = self._create_handler()

    succeeded = handler.shutdown()

    self.assertFalse(succeeded)
def test_shutdown_can_be_called_multiple_times(self) -> None:
    """Calling ``shutdown`` twice in a row leaves the rendezvous closed."""
    handler = self._create_handler()

    for _ in range(2):
        handler.shutdown()

    self.assertTrue(self._state.closed)
@patch("torch.distributed.elastic.events.record_rdzv_event")
def test_shutdown_records_and_raises_exceptions(self, record_mock) -> None:
    """A ``RuntimeError`` raised by ``_close`` propagates out of ``shutdown``."""
    failing_close = patch.object(DynamicRendezvousHandler, "_close")
    with failing_close as close_mock:
        close_mock.side_effect = RuntimeError("test error")

        handler = self._create_handler()

        with self.assertRaises(RuntimeError):
            handler.shutdown()
            # NOTE(review): this line follows the raising call inside the
            # assertRaises block, so it is not reached as written - confirm
            # the intended nesting before relying on it.
            record_mock.assert_called_once()
@patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime")
def test_keep_alive_updates_last_heartbeat(self, mock_datetime) -> None:
    """``_keep_alive`` replaces an overdue heartbeat with the (mocked) current time."""
    frozen_now = datetime(2000, 1, 1, hour=0, minute=0)
    mock_datetime.utcnow.return_value = frozen_now

    # Seed a heartbeat that is two keep-alive intervals old.
    overdue = frozen_now - (self._keep_alive_interval * 2)
    self._state.last_heartbeats[self._node] = overdue

    handler = self._create_handler()
    handler._keep_alive()

    self.assertEqual(self._state.last_heartbeats[self._node], frozen_now)
def _assert_keep_alive_swallows_rendezvous_errors(self) -> None:
    """Run ``_keep_alive`` and check it returns normally without touching the heartbeat.

    Callers arrange for the sync mock to fail or stall before invoking this
    helper.
    """
    overdue = datetime.now(timezone.utc) - (self._keep_alive_interval * 2)
    self._state.last_heartbeats[self._node] = overdue

    handler = self._create_handler()
    handler._keep_alive()

    # The stored heartbeat must still be the value seeded above.
    self.assertEqual(self._state.last_heartbeats[self._node], overdue)
def test_keep_alive_swallows_rendezvous_errors(self) -> None:
    """``_keep_alive`` must not propagate a ``RendezvousError`` raised by sync."""
    # A one-element list makes the mock raise on its first call only.
    failures = [RendezvousError]
    self._mock_sync.side_effect = failures

    self._assert_keep_alive_swallows_rendezvous_errors()
def test_keep_alive_respects_the_requested_timeout(self) -> None:
    """With a 0.3 s sync and a 0.2 s heartbeat timeout the heartbeat stays unchanged."""

    def slow_sync():
        time.sleep(0.3)

    self._mock_sync.side_effect = slow_sync
    self._heartbeat_timeout = timedelta(seconds=0.2)

    self._assert_keep_alive_swallows_rendezvous_errors()
def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_shutdown(
    self,
) -> None:
    """The keep-alive timer thread exists after ``next_rendezvous`` and is gone after ``shutdown``."""
    timer_name = "RendezvousKeepAliveTimer_2"

    def timer_is_running() -> bool:
        return any(t.name == timer_name for t in threading.enumerate())

    self._node = _NodeDesc("this_node", 1, 2)

    handler = self._create_handler()
    self.assertFalse(timer_is_running())

    handler.next_rendezvous()
    self.assertTrue(timer_is_running())

    handler.shutdown()
    self.assertFalse(timer_is_running())
def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_finalizer(
    self,
) -> None:
    """The keep-alive timer thread exists after ``next_rendezvous`` and is gone once the handler is deleted."""
    timer_name = "RendezvousKeepAliveTimer_3"

    # This helper deliberately closes over the name only, never the handler,
    # so ``del handler`` below really drops the last local reference.
    def timer_is_running() -> bool:
        return any(t.name == timer_name for t in threading.enumerate())

    self._node = _NodeDesc("this_node", 1, 3)

    handler = self._create_handler()
    self.assertFalse(timer_is_running())

    handler.next_rendezvous()
    self.assertTrue(timer_is_running())

    del handler
    self.assertFalse(timer_is_running())
class DummyRendezvousBackend(RendezvousBackend):
    """Backend stub that stores nothing; it only supplies a backend name."""

    @property
    def name(self) -> str:
        """Fixed backend name reported by this stub."""
        return "dummy_backend"

    def get_state(self) -> None:
        """Report that no state is stored."""
        return

    def set_state(self, state, token) -> None:
        """Ignore the write and return ``None``."""
        return
class DynamicRendezvousHandlerFromBackendTest(TestCase):
    """Tests for the ``DynamicRendezvousHandler.from_backend`` factory."""

    def setUp(self) -> None:
        self._run_id = "dummy_run_id"
        self._store = DummyStore()
        self._backend = DummyRendezvousBackend()
        self._min_nodes = 3
        self._max_nodes = 6
        self._timeout: Optional[RendezvousTimeout] = RendezvousTimeout()

    def _create_handler(self) -> DynamicRendezvousHandler:
        """Call ``from_backend`` with the fixture's current attribute values."""
        factory_kwargs = {
            "run_id": self._run_id,
            "store": self._store,
            "backend": self._backend,
            "min_nodes": self._min_nodes,
            "max_nodes": self._max_nodes,
            "timeout": self._timeout,
        }
        return DynamicRendezvousHandler.from_backend(**factory_kwargs)

    def test_init_initializes_handler(self) -> None:
        """The handler exposes the backend name, run id and settings it was built with."""
        handler = self._create_handler()

        self.assertEqual(handler.get_backend(), self._backend.name)
        self.assertEqual(handler.get_run_id(), self._run_id)

        settings = handler.settings
        self.assertEqual(settings.run_id, self._run_id)
        self.assertEqual(settings.min_nodes, self._min_nodes)
        self.assertEqual(settings.max_nodes, self._max_nodes)

        if self._timeout is not None:
            # An explicit timeout object must be passed through untouched.
            self.assertIs(settings.timeout, self._timeout)
        else:
            self.assertIsNotNone(settings.timeout)

    def test_init_initializes_handler_if_timeout_is_not_specified(self) -> None:
        """Re-run the initialization checks with no timeout supplied."""
        self._timeout = None

        self.test_init_initializes_handler()

    def test_init_initializes_handler_if_min_and_max_nodes_are_equal(self) -> None:
        """Re-run the initialization checks with ``min_nodes == max_nodes``."""
        self._min_nodes = self._max_nodes = 3

        self.test_init_initializes_handler()

    def test_init_raises_error_if_min_nodes_is_not_positive(self) -> None:
        """Zero or negative ``min_nodes`` is rejected with a ``ValueError``."""
        for num in (0, -10):
            with self.subTest(min_nodes=num):
                self._min_nodes = num

                expected_message = (
                    rf"^The minimum number of nodes \({num}\) must be greater than zero.$"
                )
                with self.assertRaisesRegex(ValueError, expected_message):
                    self._create_handler()

    def test_init_raises_error_if_max_nodes_is_less_than_min(self) -> None:
        """``max_nodes`` below ``min_nodes`` is rejected with a ``ValueError``."""
        self._min_nodes = 3
        self._max_nodes = 2

        expected_message = (
            rf"^The maximum number of nodes \({self._max_nodes}\) must be greater than or equal to "
            "the minimum number of nodes "
            rf"\({self._min_nodes}\).$"
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            self._create_handler()
class CreateHandlerTest(TestCase):
    """Tests for the module-level ``create_handler`` factory."""

    def setUp(self) -> None:
        self._store = DummyStore()
        self._backend = DummyRendezvousBackend()

        self._params = RendezvousParameters(
            backend=self._backend.name,
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=3,
            max_nodes=6,
            join_timeout="50",
            last_call_timeout="60",
            close_timeout="70",
        )

        # Same three values as the string parameters above, as timedeltas.
        join, last_call, close = (timedelta(seconds=s) for s in (50, 60, 70))
        self._expected_timeout = RendezvousTimeout(join, last_call, close)

    def test_create_handler_returns_handler(self) -> None:
        """The created handler reflects the backend, run id, node limits and timeouts."""
        handler = create_handler(self._store, self._backend, self._params)

        self.assertEqual(handler.get_backend(), self._backend.name)
        self.assertEqual(handler.get_run_id(), self._params.run_id)

        settings = handler.settings
        self.assertEqual(settings.min_nodes, self._params.min_nodes)
        self.assertEqual(settings.max_nodes, self._params.max_nodes)

        for field in ("join", "last_call", "close"):
            self.assertEqual(
                getattr(settings.timeout, field),
                getattr(self._expected_timeout, field),
            )

    def test_create_handler_returns_handler_if_timeout_is_not_specified(self) -> None:
        """Without timeout parameters the handler uses ``RendezvousTimeout()`` defaults."""
        for key in ("join_timeout", "last_call_timeout", "close_timeout"):
            del self._params.config[key]

        self._expected_timeout = RendezvousTimeout()

        self.test_create_handler_returns_handler()

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    def test_create_handler_records_and_raises_exceptions(self, record_mock) -> None:
        """A ``RendezvousError`` raised by ``from_backend`` propagates out of ``create_handler``."""
        failing_factory = patch.object(DynamicRendezvousHandler, "from_backend")
        with failing_factory as from_mock:
            from_mock.side_effect = RendezvousError("test error")

            with self.assertRaises(RendezvousError):
                create_handler(self._store, self._backend, self._params)
                # NOTE(review): this line follows the raising call inside the
                # assertRaises block, so it is not reached as written -
                # confirm the intended nesting before relying on it.
                record_mock.assert_called_once()

    def test_create_handler_rdzv_local_addr(self) -> None:
        """An explicit ``local_addr`` becomes the bootstrap store's master address."""
        local_addr = "127.0.0.2"
        single_node_params = RendezvousParameters(
            backend=self._backend.name,
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            join_timeout="50",
            last_call_timeout="60",
            close_timeout="70",
            local_addr=local_addr,
        )

        handler = create_handler(HashStore(), self._backend, single_node_params)
        rdzv_info = handler.next_rendezvous()

        self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "127.0.0.2")
def _ignore_exception(exception_type: Exception, fn: Callable):
try:
fn()
except exception_type:
pass
def _wait_for(condition, timeout=10, interval=1, name=None):
def _wait_while():
while True:
if condition():
break
else:
time.sleep(interval)
wait_thread = threading.Thread(target=_wait_while, name=name)
wait_thread.start()
wait_thread.join(timeout=timeout)
class _CapturingThread(threading.Thread):
def __init__(self, target=None, name=None, args=None, kwargs=None):
if args is None:
args = ()
if kwargs is None:
kwargs = {}
threading.Thread.__init__(
self, target=target, args=args, kwargs=kwargs, name=name
)
self._result = None
def run(self):
if self._target is not None:
self._result = self._target(*self._args, **self._kwargs)
def join(self, *args):
threading.Thread.join(self, *args)
return self._result
class IntegrationTest(TestCase):
    """Multi-handler scenarios driven through ``create_handler``.

    All handlers created by a test share ``self._store`` and one in-memory
    backend, and run ``next_rendezvous`` on separate threads.
    """

    def setUp(self) -> None:
        self._store = HashStore()
        self._handlers = []
        self._backend = _InMemoryRendezvousBackend()

    def tearDown(self) -> None:
        # Stop the heartbeats of every handler the test created.
        for created in self._handlers:
            created._stop_heartbeats()

    def _create_handler(self, **kwargs) -> DynamicRendezvousHandler:
        """Create and register a handler; ``kwargs`` override the default parameters.

        Each handler's default local address ends in the number of handlers
        created before it.
        """
        defaults = {
            "backend": self._backend.name,
            "endpoint": "dummy_endpoint",
            "run_id": "dummy_run_id",
            "min_nodes": 2,
            "max_nodes": 2,
            "join_timeout": "5",
            "local_addr": f"127.0.0.{len(self._handlers)}",
        }
        rdzv_params = RendezvousParameters(**{**defaults, **kwargs})

        handler = create_handler(self._store, self._backend, rdzv_params)
        self._handlers.append(handler)
        return handler

    def _rendezvous_concurrently(self, *handlers):
        """Run ``next_rendezvous`` of each handler on its own thread; return the results in order."""
        threads = [_CapturingThread(target=h.next_rendezvous) for h in handlers]
        for thread in threads:
            thread.start()
        return [thread.join() for thread in threads]

    def _start_expecting_timeout(self, handler):
        """Start ``handler.next_rendezvous`` on a thread that ignores ``RendezvousTimeoutError``."""
        thread = _CapturingThread(
            target=_ignore_exception,
            args=(RendezvousTimeoutError, handler.next_rendezvous),
        )
        thread.start()
        return thread

    def _load_backend_state(self):
        """Unpickle the state currently held by the in-memory backend."""
        return pickle.loads(self._backend.get_state()[0])

    def test_all_nodes_join_rendezvous(self) -> None:
        """Two handlers rendezvous together, sharing the store but not the rank."""
        handlers = [self._create_handler(min_nodes=2, max_nodes=2) for _ in range(2)]

        rdzv_info1, rdzv_info2 = self._rendezvous_concurrently(*handlers)

        for info in (rdzv_info1, rdzv_info2):
            self.assertEqual(info.store.underlying_store, self._store)

        self.assertNotEqual(rdzv_info1.rank, rdzv_info2.rank)

        for info in (rdzv_info1, rdzv_info2):
            self.assertEqual(info.world_size, 2)

    def test_redundancy_list(self) -> None:
        """A third handler arriving after a full rendezvous lands in the redundancy list."""
        handler1, handler2, handler3 = [
            self._create_handler(min_nodes=2, max_nodes=2) for _ in range(3)
        ]

        # establish successful rendezvous
        self._rendezvous_concurrently(handler1, handler2)

        # expect to register in redundancy list
        self._start_expecting_timeout(handler3)

        # wait until the handler3 is registered in the redundancy list
        _wait_for(lambda: self._load_backend_state().redundancy_list)

        state = self._load_backend_state()
        addresses = [node.addr for node in state.redundancy_list]
        self.assertListEqual(addresses, ["127.0.0.2"])

    def test_redundancy_transition_to_wait_list_then_join_rendezvous(self) -> None:
        """Wait for one participant and one wait-list entry after a participant stops heartbeating.

        NOTE(review): ``_wait_for`` returns without error when its timeout
        expires, so as written this test only waits and asserts nothing.
        """
        handler1 = self._create_handler(min_nodes=1, max_nodes=2)
        handler2 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
            keep_alive_interval=timedelta(seconds=1),
        )
        handler3 = self._create_handler(min_nodes=1, max_nodes=2)

        # establish successful rendezvous
        self._rendezvous_concurrently(handler1, handler2)

        self._start_expecting_timeout(handler3)

        _wait_for(lambda: self._load_backend_state().redundancy_list)

        handler2._stop_heartbeats()

        _wait_for(lambda: len(self._load_backend_state().participants) == 1)
        _wait_for(lambda: len(self._load_backend_state().wait_list) == 1)

    def test_use_agent_store_is_true_by_default(self):
        """``use_agent_store`` is true when the handler is created with default settings."""
        handler = self._create_handler(min_nodes=1, max_nodes=2)

        self.assertTrue(handler.use_agent_store)

    @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"})
    def test_use_agent_store_is_disabled(self):
        """``use_agent_store`` is false while ``TORCH_DISABLE_SHARE_RDZV_TCP_STORE`` is ``1``."""
        handler = self._create_handler(min_nodes=1, max_nodes=2)

        self.assertFalse(handler.use_agent_store)

    @patch.object(dist, "PrefixStore")
    def test_share_tcp_store_from_backend(self, prefix_store_class_mock):
        """With a ``TCPStore`` backend store, both nodes read bootstrap info from the prefix store."""
        expected_addr = "expected_address"
        expected_port = 54231

        class CustomPrefixStore(Mock):
            def get(self, key):
                if key == "MASTER_ADDR":
                    return expected_addr.encode("utf-8")
                return bytes(str(expected_port), "utf-8")

            def set(self, key, value):
                pass

        prefix_store = CustomPrefixStore(spec=dist.PrefixStore)
        prefix_store_class_mock.return_value = prefix_store

        # The TCP store advertises values that differ from the expected ones,
        # so the assertions show the prefix store's answers were used.
        tcp_store = Mock(spec=dist.TCPStore)
        original_addr = "original_addr"
        original_port = TEST_PORT
        type(tcp_store).host = PropertyMock(return_value=original_addr)
        type(tcp_store).port = PropertyMock(return_value=original_port)
        # this will be injected
        self._store = tcp_store

        handlers = [self._create_handler(min_nodes=2, max_nodes=2) for _ in range(2)]
        rdzv_info1, rdzv_info2 = self._rendezvous_concurrently(*handlers)

        self.assertEqual(rdzv_info1.store, prefix_store)
        self.assertEqual(rdzv_info2.store, prefix_store)
        prefix_store_class_mock.assert_called_with(
            "torch.rendezvous.dummy_run_id.0", tcp_store
        )

        bootstrap1 = rdzv_info1.bootstrap_store_info
        self.assertEqual(bootstrap1, rdzv_info2.bootstrap_store_info)
        self.assertEqual(rdzv_info1.bootstrap_store_info.master_addr, expected_addr)
        self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, expected_port)

    @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"})
    @patch.object(dist, "PrefixStore")
    def test_share_tcp_store_is_disabled(self, prefix_store_class_mock):
        """With sharing disabled, the master port is the value read back from the prefix store."""
        prefix_store = Mock()
        prefix_store.set.return_value = None
        prefix_store.get.return_value = b"123"
        prefix_store_class_mock.return_value = prefix_store

        # this will be injected
        self._store = Mock(spec=dist.TCPStore)

        handlers = [self._create_handler(min_nodes=2, max_nodes=2) for _ in range(2)]
        rdzv_info1, rdzv_info2 = self._rendezvous_concurrently(*handlers)

        self.assertEqual(rdzv_info1.store, prefix_store)
        self.assertEqual(rdzv_info2.store, prefix_store)
        prefix_store_class_mock.assert_called_with(
            "torch.rendezvous.dummy_run_id.0", self._store
        )

        for info in (rdzv_info1, rdzv_info2):
            self.assertEqual(info.bootstrap_store_info.master_port, 123)
class _InMemoryRendezvousBackend(RendezvousBackend):
def __init__(self) -> None:
self._lock = threading.Lock()
self._state = None
self._token = None
@property
def name(self):
return "_in_memory_backend"
def get_state(self):
with self._lock:
if self._state is None:
return None
return (self._state, self._token)
return self._state
def set_state(self, state, token):
if state is None:
raise ValueError("State cannot be None.")
with self._lock:
if token is None and self._token is not None:
return None
if self._token != token:
return None
self._state = state
self._token = self._token + 1 if self._token is not None else 0