Revert D27572158: [torchelastic] Make sure torchelastic mp wait for queue to be drained before finishing the process

Test Plan: revert-hammer

Differential Revision: D27572158 (e9c6a51100)

Original commit changeset: 9a360468acc9

fbshipit-source-id: 29f7e2cba3e134bc81fb31b7e1dfceb7c1f9d734
Brian Hirsh 2021-04-06 11:40:34 -07:00 committed by Facebook GitHub Bot
parent 8e78a1b084
commit ae3a876c9c
3 changed files with 5 additions and 67 deletions
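The api.py hunks below show the mechanism being reverted: each worker wrote its return value to a per-rank SimpleQueue and then blocked on queue_finished_reading_event, which the parent set only after draining every return queue. A minimal standalone sketch of that handshake (plain multiprocessing, illustrative names, not the torchelastic implementation):

    import multiprocessing as mp


    def worker(rank: int, ret_queue, finished_reading_event) -> None:
        # publish the return value, then stay alive until the parent has read it
        ret_queue.put((rank, f"result_from_{rank}"))
        finished_reading_event.wait()


    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        nprocs = 2
        finished_reading_event = ctx.Event()
        ret_queues = [ctx.SimpleQueue() for _ in range(nprocs)]
        procs = [
            ctx.Process(target=worker, args=(rank, ret_queues[rank], finished_reading_event))
            for rank in range(nprocs)
        ]
        for p in procs:
            p.start()

        # drain every return queue first, then release the workers
        results = dict(q.get() for q in ret_queues)
        finished_reading_event.set()
        for p in procs:
            p.join()
        print(results)

Reverting removes both halves of this handshake, so a worker exits as soon as it has put its return value.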


@@ -71,13 +71,6 @@ def _sad_function():
     raise RuntimeError("sad because i throw")
-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
 def _fatal_signal_function(expected_error_index: int, sig: int):
     rank = int(os.environ["RANK"])
     if rank == expected_error_index:
@@ -313,16 +306,6 @@ class LocalElasticAgentTest(unittest.TestCase):
             results.setdefault(role, []).append(run_result)
         return results
-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_dummy_compute(self):
-        res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
-        self.assertFalse(res.is_failed())
-        for return_value in res.return_values.values():
-            self.assertIsInstance(return_value, torch.Tensor)
-            self.assertEqual((100, 100), return_value.shape)
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )


@@ -18,7 +18,6 @@ from itertools import product
 from typing import Dict, List
 from unittest import mock
-import torch
 import torch.multiprocessing as mp
 from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
 from torch.distributed.elastic.multiprocessing.api import (
@@ -145,13 +144,6 @@ def echo_large(size: int) -> Dict[int, str]:
     return out
-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
 def redirects() -> List[Std]:
     return [
         Std.NONE,
@@ -213,7 +205,6 @@ class StartProcessesTest(unittest.TestCase):
         for stdout_redir, stderr_redir in redirs:
             queue = multiprocessing.SimpleQueue()
-            worker_finished_event_mock = mock.Mock()
             _wrap(
                 local_rank=0,
                 fn=echo1,
@@ -222,14 +213,12 @@ class StartProcessesTest(unittest.TestCase):
                 stdout_redirects={0: stdout_redir},
                 stderr_redirects={0: stderr_redir},
                 ret_vals={0: queue},
-                queue_finished_reading_event=worker_finished_event_mock,
             )
             self.assertEqual("hello_0", queue.get())
             if stdout_redir:
                 self.assert_in_file(["hello stdout from 0"], stdout_log)
             if stderr_redir:
                 self.assert_in_file(["hello stderr from 0"], stderr_log)
-            worker_finished_event_mock.wait.assert_called_once()
     def test_invalid_log_dir(self):
         with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
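In the hunk above, the direct _wrap test passed a mock.Mock() in place of the real event so the call under test did not block, while the wait() call could still be asserted afterwards. A self-contained sketch of that trick, with a hypothetical put_and_wait helper standing in for _wrap:

    import multiprocessing
    from unittest import mock


    def put_and_wait(value, ret_queue, finished_reading_event) -> None:
        # hypothetical stand-in for _wrap: publish the result, then block until
        # the consumer signals that it has read the queue
        ret_queue.put(value)
        finished_reading_event.wait()


    if __name__ == "__main__":
        queue = multiprocessing.SimpleQueue()
        finished_event = mock.Mock()  # .wait() becomes a recorded no-op
        put_and_wait("hello_0", queue, finished_event)
        assert queue.get() == "hello_0"
        finished_event.wait.assert_called_once()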
@@ -350,26 +339,6 @@ class StartProcessesTest(unittest.TestCase):
                         [f"hello stderr from {i}"], results.stderrs[i]
                     )
-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_function_with_tensor(self):
-        for start_method in self._start_methods:
-            pc = start_processes(
-                name="dummy_compute",
-                entrypoint=dummy_compute,
-                args={},
-                envs={},
-                log_dir=self.log_dir(),
-                start_method=start_method,
-            )
-            results = pc.wait()
-            self.assert_pids_noexist(pc.pids())
-            for return_value in results.return_values.values():
-                self.assertIsInstance(return_value, torch.Tensor)
-                self.assertEqual((100, 100), return_value.shape)
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
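The removed test_function_with_tensor checked end to end that a worker could hand a torch.Tensor back to the parent and that each rank's return value had the expected 100x100 shape. A rough standalone analogue using torch.multiprocessing directly rather than start_processes (hypothetical worker name; it keeps the event handshake so the producer stays alive while the tensor's shared-memory storage is transferred):

    import torch
    import torch.multiprocessing as mp


    def dummy_compute_worker(ret_queue, finished_reading_event) -> None:
        # hand a fixed-size random tensor to the parent, then stay alive until
        # the parent confirms it has read the queue
        ret_queue.put(torch.rand(100, 100))
        finished_reading_event.wait()


    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        ret_queue = ctx.SimpleQueue()
        finished_reading_event = ctx.Event()
        proc = ctx.Process(
            target=dummy_compute_worker, args=(ret_queue, finished_reading_event)
        )
        proc.start()

        result = ret_queue.get()
        assert isinstance(result, torch.Tensor)
        assert result.shape == (100, 100)
        finished_reading_event.set()
        proc.join()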


@@ -271,7 +271,6 @@ def _wrap(
     stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
     stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
     ret_vals: Dict[int, mp.SimpleQueue],
-    queue_finished_reading_event: mp.Event,
 ) -> None:
     # get the per-rank params up front so we fail fast if no mapping is found
     args_ = args[local_rank]
@@ -290,7 +289,6 @@ def _wrap(
     with stdout_cm, stderr_cm:
         ret = record(fn)(*args_)
     ret_val_.put(ret)
-    queue_finished_reading_event.wait()
 class MultiprocessContext(PContext):
@@ -333,9 +331,6 @@ class MultiprocessContext(PContext):
         # see comments in ``join()`` for what this is
         self._return_values: Dict[int, Any] = {}
         self._pc: Optional[mp.ProcessContext] = None
-        # Note: set method should ONLY be invoked for the use case when all processes finished
-        # successfully. If any process died on event.wait() calling set() method will deadlock.
-        self._worker_finished_event = mp.get_context(self.start_method).Event()
     def _start(self):
         if self._pc:
@@ -352,7 +347,6 @@ class MultiprocessContext(PContext):
                 self.stdouts,
                 self.stderrs,
                 self._ret_vals,
-                self._worker_finished_event,
             ),
             nprocs=self.nprocs,
             join=False,
@@ -360,19 +354,15 @@ class MultiprocessContext(PContext):
             start_method=self.start_method,
         )
-    def _is_done(self) -> bool:
-        return len(self._return_values) == self.nprocs
     def _poll(self) -> Optional[RunProcsResult]:
         assert self._pc is not None  # assertion for mypy type checker
         try:
-            # torch.mp.ProcessContext Throws an Exception if some/all of
-            # worker processes failed
+            # torch.mp.ProcessContext returns True if all the workers have
+            # successfully finished, False if some/all are still running
+            # and throws an Exception if some/all of them failed
             # timeout < 0 checks worker status and return immediately
-            # Join will never return success since we use mp.Event to wait
-            # for all processes to finish.
-            self._pc.join(-1)
+            done = self._pc.join(-1)
             # IMPORTANT: we use multiprocessing.Queue to carry worker return values
             # back to the parent, the worker process will wait before terminating
@@ -386,12 +376,8 @@ class MultiprocessContext(PContext):
                     # save the return values temporarily into a member var
                     self._return_values[local_rank] = return_queue.get()
-                if self._is_done():
+                if done:
                     # we should ALWAYS have ALL the return values when all the processes are done
-                    self._worker_finished_event.set()
-                    # Wait untill all processes are finished. At this point workers finished executing
-                    # user function
-                    self._pc.join()
                     _validate_full_rank(
                         self._return_values, self.nprocs, "return_value queue"
                     )
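After the revert, _poll goes back to relying on ProcessContext.join(-1): poll worker status, opportunistically drain each rank's return queue, and only once join reports that all workers exited, check that every rank produced a return value. A condensed, hypothetical sketch of that polling pattern (plain picklable return values and illustrative names, not the MultiprocessContext code):

    import time

    import torch.multiprocessing as mp


    def worker(local_rank: int, ret_queues) -> None:
        # write the per-rank result and exit immediately (no drain handshake)
        ret_queues[local_rank].put(f"result_from_{local_rank}")


    if __name__ == "__main__":
        nprocs = 2
        ctx = mp.get_context("spawn")
        ret_queues = {rank: ctx.SimpleQueue() for rank in range(nprocs)}
        pc = mp.start_processes(
            worker, args=(ret_queues,), nprocs=nprocs, join=False, start_method="spawn"
        )

        return_values = {}
        while True:
            # join(-1) checks worker status and returns immediately:
            # True once every worker exited; it raises if any worker failed
            done = pc.join(-1)
            for rank, queue in ret_queues.items():
                if rank not in return_values and not queue.empty():
                    return_values[rank] = queue.get()
            if done:
                assert len(return_values) == nprocs, "return_value queue is missing ranks"
                break
            time.sleep(0.1)
        print(return_values)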