mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
[c10d] Fix the hang issue in store.check(TIMEOUT_DUMP) (#116297)
Summary: We have found out the root cause of the hang is NOT due to destruction of stores. The hang in the check() only happens when the store is of type FileStore. The file held by each filestore was a temp file, which was created by Python Tempfile, it was deleted by default when the file was closed. Note that the file was opened and closed by every check() in the watchdog and in constructor of FileStore. The when check() tried to open the deleted file again, open() would fail after the timeout value (by default 5 mins), hence the hang happened. The fix is simple, just avoid the default deletion after the file is closed. Test Plan: 1. We first reproduce the hang in check() in the existing unit test: test_init_process_group_for_all_backends by enabling the DumpOnTimeOut and making the main thread sleep for 2s, to give enough time for tempfile to be deleted 2. Adding log to check ref count of fileStore and also the sequence of file opening and closing 3. With the repro, an exception will be thrown as "no such file or directory' and unit test would fail 4. Verify the tests now passes with the above knob change 5. add an unit test in test_c10d_nccl to cover the fileStore check() code path python test/distributed/test_c10d_common.py ProcessGroupWithDispatchedCollectivesTests python test/distributed/test_c10d_nccl.py ProcessGroupNCCLTest.test_file_store_check Reviewers: Subscribers: Tasks: T173200093 Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/116297 Approved by: https://github.com/fduwjj ghstack dependencies: #116296
This commit is contained in:
parent
94f3781145
commit
b47aa69685
|
|
@ -133,7 +133,7 @@ class AbstractTimeoutTest:
|
|||
class TimeoutTest(TestCase):
|
||||
@retry_on_connect_failures
|
||||
def test_store_based_barrier(self):
|
||||
f = tempfile.NamedTemporaryFile()
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
port = common.find_free_port()
|
||||
|
||||
def thread_work(timeout, init_type, world_size, rank, error_list):
|
||||
|
|
@ -1756,7 +1756,7 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
|
|||
pass
|
||||
|
||||
def test_init_process_group_optional_backend(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
store = dist.FileStore(f.name, self.world_size)
|
||||
# creates both gloo and nccl backend
|
||||
if dist.is_gloo_available() and dist.is_nccl_available():
|
||||
|
|
@ -1785,7 +1785,7 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
|
|||
if not dist.is_ucc_available():
|
||||
continue
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
store = dist.FileStore(f.name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
|
|
|
|||
|
|
@ -1197,6 +1197,31 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
|
|||
with self.assertRaises(dist.DistBackendError):
|
||||
pg.allreduce([t])
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
def test_file_store_check(self):
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
os.environ["TORCH_NCCL_ENABLE_MONITORING"] = "0"
|
||||
# FileStore check() would be executed
|
||||
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
|
||||
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "0"
|
||||
|
||||
# self.file_name is created using "delete=False"
|
||||
# e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
store=store
|
||||
)
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
self.assertEqual(pg.rank(), self.rank)
|
||||
self.assertEqual(pg.size(), self.world_size)
|
||||
# give enough time for check() to be executed multiple times
|
||||
time.sleep(2)
|
||||
dist.destroy_process_group()
|
||||
|
||||
def _check_nccl_timeout(self, expected_timeout):
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
options = pg._get_backend(torch.device(f"cuda:{self.rank}")).options
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user