Revert D27572158: [torchelastic] Make sure torchelastic mp wait for queue to be drained before finishing the process

Test Plan: revert-hammer

Differential Revision:
D27572158 (e9c6a51100)

Original commit changeset: 9a360468acc9

fbshipit-source-id: 29f7e2cba3e134bc81fb31b7e1dfceb7c1f9d734
This commit is contained in:
Brian Hirsh 2021-04-06 11:40:34 -07:00 committed by Facebook GitHub Bot
parent 8e78a1b084
commit ae3a876c9c
3 changed files with 5 additions and 67 deletions

View File

@ -71,13 +71,6 @@ def _sad_function():
raise RuntimeError("sad because i throw") raise RuntimeError("sad because i throw")
def dummy_compute() -> torch.Tensor:
"""
returns a predefined size random Tensor
"""
return torch.rand(100, 100)
def _fatal_signal_function(expected_error_index: int, sig: int): def _fatal_signal_function(expected_error_index: int, sig: int):
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
if rank == expected_error_index: if rank == expected_error_index:
@ -313,16 +306,6 @@ class LocalElasticAgentTest(unittest.TestCase):
results.setdefault(role, []).append(run_result) results.setdefault(role, []).append(run_result)
return results return results
@unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
)
def test_dummy_compute(self):
res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
self.assertFalse(res.is_failed())
for return_value in res.return_values.values():
self.assertIsInstance(return_value, torch.Tensor)
self.assertEqual((100, 100), return_value.shape)
@unittest.skipIf( @unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
) )

View File

@ -18,7 +18,6 @@ from itertools import product
from typing import Dict, List from typing import Dict, List
from unittest import mock from unittest import mock
import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
from torch.distributed.elastic.multiprocessing.api import ( from torch.distributed.elastic.multiprocessing.api import (
@ -145,13 +144,6 @@ def echo_large(size: int) -> Dict[int, str]:
return out return out
def dummy_compute() -> torch.Tensor:
"""
returns a predefined size random Tensor
"""
return torch.rand(100, 100)
def redirects() -> List[Std]: def redirects() -> List[Std]:
return [ return [
Std.NONE, Std.NONE,
@ -213,7 +205,6 @@ class StartProcessesTest(unittest.TestCase):
for stdout_redir, stderr_redir in redirs: for stdout_redir, stderr_redir in redirs:
queue = multiprocessing.SimpleQueue() queue = multiprocessing.SimpleQueue()
worker_finished_event_mock = mock.Mock()
_wrap( _wrap(
local_rank=0, local_rank=0,
fn=echo1, fn=echo1,
@ -222,14 +213,12 @@ class StartProcessesTest(unittest.TestCase):
stdout_redirects={0: stdout_redir}, stdout_redirects={0: stdout_redir},
stderr_redirects={0: stderr_redir}, stderr_redirects={0: stderr_redir},
ret_vals={0: queue}, ret_vals={0: queue},
queue_finished_reading_event=worker_finished_event_mock,
) )
self.assertEqual("hello_0", queue.get()) self.assertEqual("hello_0", queue.get())
if stdout_redir: if stdout_redir:
self.assert_in_file(["hello stdout from 0"], stdout_log) self.assert_in_file(["hello stdout from 0"], stdout_log)
if stderr_redir: if stderr_redir:
self.assert_in_file(["hello stderr from 0"], stderr_log) self.assert_in_file(["hello stderr from 0"], stderr_log)
worker_finished_event_mock.wait.assert_called_once()
def test_invalid_log_dir(self): def test_invalid_log_dir(self):
with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir: with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
@ -350,26 +339,6 @@ class StartProcessesTest(unittest.TestCase):
[f"hello stderr from {i}"], results.stderrs[i] [f"hello stderr from {i}"], results.stderrs[i]
) )
@unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
)
def test_function_with_tensor(self):
for start_method in self._start_methods:
pc = start_processes(
name="dummy_compute",
entrypoint=dummy_compute,
args={},
envs={},
log_dir=self.log_dir(),
start_method=start_method,
)
results = pc.wait()
self.assert_pids_noexist(pc.pids())
for return_value in results.return_values.values():
self.assertIsInstance(return_value, torch.Tensor)
self.assertEqual((100, 100), return_value.shape)
@unittest.skipIf( @unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
) )

View File

@ -271,7 +271,6 @@ def _wrap(
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None) stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None) stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue], ret_vals: Dict[int, mp.SimpleQueue],
queue_finished_reading_event: mp.Event,
) -> None: ) -> None:
# get the per-rank params up front so we fail fast if no mapping is found # get the per-rank params up front so we fail fast if no mapping is found
args_ = args[local_rank] args_ = args[local_rank]
@ -290,7 +289,6 @@ def _wrap(
with stdout_cm, stderr_cm: with stdout_cm, stderr_cm:
ret = record(fn)(*args_) ret = record(fn)(*args_)
ret_val_.put(ret) ret_val_.put(ret)
queue_finished_reading_event.wait()
class MultiprocessContext(PContext): class MultiprocessContext(PContext):
@ -333,9 +331,6 @@ class MultiprocessContext(PContext):
# see comments in ``join()`` for what this is # see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {} self._return_values: Dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None self._pc: Optional[mp.ProcessContext] = None
# Note: set method should ONLY be invoked for the use case when all processes finished
# successfully. If any process died on event.wait() calling set() method will deadlock.
self._worker_finished_event = mp.get_context(self.start_method).Event()
def _start(self): def _start(self):
if self._pc: if self._pc:
@ -352,7 +347,6 @@ class MultiprocessContext(PContext):
self.stdouts, self.stdouts,
self.stderrs, self.stderrs,
self._ret_vals, self._ret_vals,
self._worker_finished_event,
), ),
nprocs=self.nprocs, nprocs=self.nprocs,
join=False, join=False,
@ -360,19 +354,15 @@ class MultiprocessContext(PContext):
start_method=self.start_method, start_method=self.start_method,
) )
def _is_done(self) -> bool:
return len(self._return_values) == self.nprocs
def _poll(self) -> Optional[RunProcsResult]: def _poll(self) -> Optional[RunProcsResult]:
assert self._pc is not None # assertion for mypy type checker assert self._pc is not None # assertion for mypy type checker
try: try:
# torch.mp.ProcessContext Throws an Exception if some/all of # torch.mp.ProcessContext returns True if all the workers have
# worker processes failed # successfully finished, False if some/all are still running
# and throws an Exception if some/all of them failed
# timeout < 0 checks worker status and return immediately # timeout < 0 checks worker status and return immediately
# Join will never return success since we use mp.Event to wait done = self._pc.join(-1)
# for all processes to finish.
self._pc.join(-1)
# IMPORTANT: we use multiprocessing.Queue to carry worker return values # IMPORTANT: we use multiprocessing.Queue to carry worker return values
# back to the parent, the worker process will wait before terminating # back to the parent, the worker process will wait before terminating
@ -386,12 +376,8 @@ class MultiprocessContext(PContext):
# save the return values temporarily into a member var # save the return values temporarily into a member var
self._return_values[local_rank] = return_queue.get() self._return_values[local_rank] = return_queue.get()
if self._is_done(): if done:
# we should ALWAYS have ALL the return values when all the processes are done # we should ALWAYS have ALL the return values when all the processes are done
self._worker_finished_event.set()
# Wait untill all processes are finished. At this point workers finished executing
# user function
self._pc.join()
_validate_full_rank( _validate_full_rank(
self._return_values, self.nprocs, "return_value queue" self._return_values, self.nprocs, "return_value queue"
) )