Revert D27572158: [torchelastic] Make sure torchelastic mp wait for queue to be drained before finishing the process
Test Plan: revert-hammer

Differential Revision: D27572158 (e9c6a51100)

Original commit changeset: 9a360468acc9

fbshipit-source-id: 29f7e2cba3e134bc81fb31b7e1dfceb7c1f9d734
parent 8e78a1b084
commit ae3a876c9c
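For context, the change being reverted made each torchelastic multiprocessing worker wait on a shared event, after publishing its return value on a per-rank queue, until the parent had finished reading that queue. Below is a minimal standalone sketch of that handshake using only the standard library; the names worker and finished_reading are illustrative, not the torchelastic API.

    import multiprocessing as mp

    def worker(ret_queue, finished_reading):
        ret_queue.put("result")      # hand the return value to the parent
        finished_reading.wait()      # do not exit until the parent has drained the queue

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        queue = ctx.SimpleQueue()
        event = ctx.Event()
        p = ctx.Process(target=worker, args=(queue, event))
        p.start()
        value = queue.get()          # drain first ...
        event.set()                  # ... then release the worker
        p.join()
        print(value)                 # -> "result"

The diff below removes exactly this machinery (the event parameter in _wrap, the event member on MultiprocessContext, and the tests that exercised it) and restores the previous join-based completion check.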
@@ -71,13 +71,6 @@ def _sad_function():
     raise RuntimeError("sad because i throw")
 
 
-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
-
-
 def _fatal_signal_function(expected_error_index: int, sig: int):
     rank = int(os.environ["RANK"])
     if rank == expected_error_index:
@@ -313,16 +306,6 @@ class LocalElasticAgentTest(unittest.TestCase):
             results.setdefault(role, []).append(run_result)
         return results
 
-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_dummy_compute(self):
-        res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
-        self.assertFalse(res.is_failed())
-        for return_value in res.return_values.values():
-            self.assertIsInstance(return_value, torch.Tensor)
-            self.assertEqual((100, 100), return_value.shape)
-
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
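The two hunks above delete the pieces the reverted commit added to the agent test: the dummy_compute helper and the test_dummy_compute case, which ran that helper as a function entrypoint on two local workers and checked every rank's return value. Stripped of the agent machinery, the property those assertions check is just the following (a direct, single-process sanity check, not part of the original test suite):

    import torch

    def dummy_compute() -> torch.Tensor:
        """returns a predefined size random Tensor"""
        return torch.rand(100, 100)

    t = dummy_compute()
    assert isinstance(t, torch.Tensor)
    assert t.shape == (100, 100)     # the shape the removed tests assert per rank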
@@ -18,7 +18,6 @@ from itertools import product
 from typing import Dict, List
-from unittest import mock
 
 import torch
 import torch.multiprocessing as mp
 from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
 from torch.distributed.elastic.multiprocessing.api import (
@@ -145,13 +144,6 @@ def echo_large(size: int) -> Dict[int, str]:
     return out
 
 
-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
-
-
 def redirects() -> List[Std]:
     return [
         Std.NONE,
@@ -213,7 +205,6 @@ class StartProcessesTest(unittest.TestCase):
 
         for stdout_redir, stderr_redir in redirs:
            queue = multiprocessing.SimpleQueue()
-            worker_finished_event_mock = mock.Mock()
             _wrap(
                 local_rank=0,
                 fn=echo1,
@@ -222,14 +213,12 @@
                 stdout_redirects={0: stdout_redir},
                 stderr_redirects={0: stderr_redir},
                 ret_vals={0: queue},
-                queue_finished_reading_event=worker_finished_event_mock,
             )
             self.assertEqual("hello_0", queue.get())
             if stdout_redir:
                 self.assert_in_file(["hello stdout from 0"], stdout_log)
             if stderr_redir:
                 self.assert_in_file(["hello stderr from 0"], stderr_log)
-            worker_finished_event_mock.wait.assert_called_once()
 
     def test_invalid_log_dir(self):
         with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
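The removed assertions above tested _wrap in-process by handing it a mock.Mock() in place of the real event: a Mock's wait() returns immediately, so the call runs inline and the test can still verify that the worker attempted to wait for the drain signal. A self-contained sketch of that pattern, with a stand-in wrap function rather than the real _wrap:

    import multiprocessing
    from unittest import mock

    def wrap(ret_queue, finished_reading_event):
        # stand-in for the real _wrap: publish a value, then wait for the parent
        ret_queue.put("hello_0")
        finished_reading_event.wait()

    queue = multiprocessing.SimpleQueue()
    event_mock = mock.Mock()              # never blocks, so wrap() runs inline
    wrap(queue, event_mock)
    assert queue.get() == "hello_0"
    event_mock.wait.assert_called_once()  # the worker did try to wait for the drain signal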
@@ -350,26 +339,6 @@ class StartProcessesTest(unittest.TestCase):
                     [f"hello stderr from {i}"], results.stderrs[i]
                 )
 
-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_function_with_tensor(self):
-        for start_method in self._start_methods:
-            pc = start_processes(
-                name="dummy_compute",
-                entrypoint=dummy_compute,
-                args={},
-                envs={},
-                log_dir=self.log_dir(),
-                start_method=start_method,
-            )
-
-            results = pc.wait()
-            self.assert_pids_noexist(pc.pids())
-            for return_value in results.return_values.values():
-                self.assertIsInstance(return_value, torch.Tensor)
-                self.assertEqual((100, 100), return_value.shape)
-
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
@@ -271,7 +271,6 @@ def _wrap(
     stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
     stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
     ret_vals: Dict[int, mp.SimpleQueue],
-    queue_finished_reading_event: mp.Event,
 ) -> None:
     # get the per-rank params up front so we fail fast if no mapping is found
     args_ = args[local_rank]
@@ -290,7 +289,6 @@ def _wrap(
     with stdout_cm, stderr_cm:
         ret = record(fn)(*args_)
     ret_val_.put(ret)
-    queue_finished_reading_event.wait()
 
 
 class MultiprocessContext(PContext):
@@ -333,9 +331,6 @@ class MultiprocessContext(PContext):
         # see comments in ``join()`` for what this is
         self._return_values: Dict[int, Any] = {}
         self._pc: Optional[mp.ProcessContext] = None
-        # Note: set method should ONLY be invoked for the use case when all processes finished
-        # successfully. If any process died on event.wait() calling set() method will deadlock.
-        self._worker_finished_event = mp.get_context(self.start_method).Event()
 
     def _start(self):
         if self._pc:
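Two details of the removed lines above are worth noting. The event was created from mp.get_context(self.start_method).Event() because, per the multiprocessing documentation, primitives created under one start method may not be compatible with processes started under another; and, as the deleted comment warns, set() should only be invoked once all workers finished successfully, otherwise the event/join sequence can deadlock. A small illustration of the context-matched event (illustrative names, not torchelastic code):

    import multiprocessing as mp

    def waiter(evt):
        evt.wait()                   # worker parks here until the parent signals
        print("released")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        evt = ctx.Event()            # created from the same context that starts the worker
        p = ctx.Process(target=waiter, args=(evt,))
        p.start()
        evt.set()
        p.join()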
@@ -352,7 +347,6 @@ class MultiprocessContext(PContext):
                 self.stdouts,
                 self.stderrs,
                 self._ret_vals,
-                self._worker_finished_event,
             ),
             nprocs=self.nprocs,
             join=False,
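For reference on the call being edited above: torch.multiprocessing.start_processes invokes the target once per process, passing the process index as the first argument followed by the entries of args, which is why _wrap receives local_rank first and why the event had to be appended to (and is now dropped from) this tuple. A minimal usage sketch; the _work function and its argument are made up for illustration:

    import torch.multiprocessing as mp

    def _work(local_rank, greeting):
        # start_processes always passes the process index first, then *args
        print(f"{greeting} from rank {local_rank}")

    if __name__ == "__main__":
        ctx = mp.start_processes(
            _work, args=("hello",), nprocs=2, join=False, start_method="spawn"
        )
        ctx.join()  # block until both workers exit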
@@ -360,19 +354,15 @@ class MultiprocessContext(PContext):
             start_method=self.start_method,
         )
 
-    def _is_done(self) -> bool:
-        return len(self._return_values) == self.nprocs
-
     def _poll(self) -> Optional[RunProcsResult]:
         assert self._pc is not None  # assertion for mypy type checker
 
         try:
-            # torch.mp.ProcessContext Throws an Exception if some/all of
-            # worker processes failed
+            # torch.mp.ProcessContext returns True if all the workers have
+            # successfully finished, False if some/all are still running
+            # and throws an Exception if some/all of them failed
             # timeout < 0 checks worker status and return immediately
-            # Join will never return success since we use mp.Event to wait
-            # for all processes to finish.
-            self._pc.join(-1)
+            done = self._pc.join(-1)
 
             # IMPORTANT: we use multiprocessing.Queue to carry worker return values
             # back to the parent, the worker process will wait before terminating
@@ -386,12 +376,8 @@ class MultiprocessContext(PContext):
                     # save the return values temporarily into a member var
                     self._return_values[local_rank] = return_queue.get()
 
-            if self._is_done():
+            if done:
                 # we should ALWAYS have ALL the return values when all the processes are done
-                self._worker_finished_event.set()
-                # Wait untill all processes are finished. At this point workers finished executing
-                # user function
-                self._pc.join()
                 _validate_full_rank(
                     self._return_values, self.nprocs, "return_value queue"
                 )
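Taken together, the last two hunks restore the pre-change completion logic: done comes straight from the non-blocking join(-1) rather than from counting collected return values in _is_done(), and the parent no longer sets the finished event or re-joins the workers before validating the results. Below is a standalone sketch of the parent-side ordering used by the change being backed out here (drain every rank's queue, then signal, then reap), again with illustrative names only:

    import multiprocessing as mp
    import time

    def worker(rank, ret_queue, finished_reading):
        ret_queue.put(rank * rank)           # pretend computation result
        finished_reading.wait()              # stay alive until the parent has read it

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        nprocs = 2
        queues = [ctx.SimpleQueue() for _ in range(nprocs)]
        event = ctx.Event()
        procs = [ctx.Process(target=worker, args=(i, queues[i], event)) for i in range(nprocs)]
        for p in procs:
            p.start()

        results = {}
        while len(results) < nprocs:         # the _is_done() condition: all return values collected
            for rank, q in enumerate(queues):
                if rank not in results and not q.empty():
                    results[rank] = q.get()  # drain opportunistically, as _poll() does
            time.sleep(0.01)

        event.set()                          # only now release the workers ...
        for p in procs:
            p.join()                         # ... and reap them
        print(results)                       # -> {0: 0, 1: 1}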