mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: When workers finish their work, the TE agent starts the `synchronize_barrier` procedure. The barrier waits for the other agents at the end of the execution. A race condition may occur: the barrier uses a TCPStore, which is located on Rank0. When Rank0 finishes its work, other ranks may still be in the process of executing the `get_all` method. This means that some of them will fail because the TCPStore will have been destroyed. The fix adds an additional check on the Rank0 process: Rank0 now waits for all other ranks to finish before terminating. Test Plan: unit tests Differential Revision: D35227180 Pull Request resolved: https://github.com/pytorch/pytorch/pull/74931 Approved by: https://github.com/kiukchung
104 lines
3.6 KiB
Python
104 lines
3.6 KiB
Python
#!/usr/bin/env python3
|
|
# Owner(s): ["oncall: r2p"]
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from unittest import mock
|
|
|
|
import torch.distributed.elastic.utils.store as store_util
|
|
from torch.distributed.elastic.utils.logging import get_logger
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class StoreUtilTest(TestCase):
|
|
def test_get_all_rank_0(self):
|
|
store = mock.MagicMock()
|
|
world_size = 3
|
|
store_util.get_all(store, 0, "test/store", world_size)
|
|
# omit empty kwargs, get only key
|
|
actual_set_call_args = [
|
|
call_args[0][0] for call_args in store.set.call_args_list
|
|
]
|
|
self.assertListEqual(["test/store0.FIN"], actual_set_call_args)
|
|
|
|
actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
|
|
expected_get_call_args = [
|
|
("test/store0",),
|
|
("test/store1",),
|
|
("test/store2",),
|
|
("test/store0.FIN",),
|
|
("test/store1.FIN",),
|
|
("test/store2.FIN",),
|
|
]
|
|
self.assertListEqual(expected_get_call_args, actual_get_call_args)
|
|
|
|
def test_get_all_rank_n(self):
|
|
store = mock.MagicMock()
|
|
world_size = 3
|
|
store_util.get_all(store, 1, "test/store", world_size)
|
|
# omit empty kwargs, get only key
|
|
actual_set_call_args = [
|
|
call_args[0][0] for call_args in store.set.call_args_list
|
|
]
|
|
self.assertListEqual(["test/store1.FIN"], actual_set_call_args)
|
|
|
|
actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
|
|
expected_get_call_args = [
|
|
("test/store0",),
|
|
("test/store1",),
|
|
("test/store2",),
|
|
]
|
|
self.assertListEqual(expected_get_call_args, actual_get_call_args)
|
|
|
|
def test_synchronize(self):
|
|
store_mock = mock.MagicMock()
|
|
data = "data0".encode(encoding="UTF-8")
|
|
store_util.synchronize(store_mock, data, 0, 3, key_prefix="torchelastic/test")
|
|
actual_set_call_args = store_mock.set.call_args_list
|
|
# omit empty kwargs
|
|
actual_set_call_args = [call_args[0] for call_args in actual_set_call_args]
|
|
expected_set_call_args = [
|
|
("torchelastic/test0", b"data0"),
|
|
("torchelastic/test0.FIN", b"FIN"),
|
|
]
|
|
self.assertListEqual(expected_set_call_args, actual_set_call_args)
|
|
|
|
expected_get_call_args = [
|
|
("torchelastic/test0",),
|
|
("torchelastic/test1",),
|
|
("torchelastic/test2",),
|
|
("torchelastic/test0.FIN",),
|
|
("torchelastic/test1.FIN",),
|
|
("torchelastic/test2.FIN",),
|
|
]
|
|
actual_get_call_args = store_mock.get.call_args_list
|
|
actual_get_call_args = [call_args[0] for call_args in actual_get_call_args]
|
|
self.assertListEqual(expected_get_call_args, actual_get_call_args)
|
|
|
|
|
|
class UtilTest(TestCase):
|
|
def test_get_logger_different(self):
|
|
logger1 = get_logger("name1")
|
|
logger2 = get_logger("name2")
|
|
self.assertNotEqual(logger1.name, logger2.name)
|
|
|
|
def test_get_logger(self):
|
|
logger1 = get_logger()
|
|
self.assertEqual(__name__, logger1.name)
|
|
|
|
def test_get_logger_none(self):
|
|
logger1 = get_logger(None)
|
|
self.assertEqual(__name__, logger1.name)
|
|
|
|
def test_get_logger_custom_name(self):
|
|
logger1 = get_logger("test.module")
|
|
self.assertEqual("test.module", logger1.name)
|
|
|
|
|
|
# Run the test classes defined above only when this file is executed as a
# script, not when it is imported.
if "__main__" == __name__:
    run_tests()
|