Revert D27572158: [torchelastic] Make sure torchelastic mp wait for queue to be drained before finishing the process
Test Plan: revert-hammer
Differential Revision: D27572158 (e9c6a51100)
Original commit changeset: 9a360468acc9
fbshipit-source-id: 29f7e2cba3e134bc81fb31b7e1dfceb7c1f9d734
This commit is contained in:
parent 8e78a1b084
commit ae3a876c9c
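For context, the change being reverted made each torchelastic worker block on a multiprocessing Event after putting its return value on a SimpleQueue, so the parent could drain the queue before the worker process exited. A minimal self-contained sketch of that handshake, using illustrative names rather than the torchelastic API:

import multiprocessing as mp

def worker(ret_queue, finished_event):
    # Hand the return value to the parent, then park until the parent
    # signals that the queue has been drained -- this mirrors the
    # ret_val_.put(ret) / queue_finished_reading_event.wait() pair
    # removed from _wrap() in the diff below.
    ret_queue.put("return-value")
    finished_event.wait()

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue, finished = ctx.SimpleQueue(), ctx.Event()
    p = ctx.Process(target=worker, args=(queue, finished))
    p.start()
    print(queue.get())  # drain the queue first ...
    finished.set()      # ... then release the worker; safe only if the
                        # worker actually reached finished_event.wait()
    p.join()

As the removed comments in the MultiprocessContext hunks note, the Event may only be set once every worker has finished successfully; a worker that dies before reaching event.wait() breaks the handshake.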
@@ -71,13 +71,6 @@ def _sad_function():
     raise RuntimeError("sad because i throw")


-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
-
-
 def _fatal_signal_function(expected_error_index: int, sig: int):
     rank = int(os.environ["RANK"])
     if rank == expected_error_index:
@@ -313,16 +306,6 @@ class LocalElasticAgentTest(unittest.TestCase):
                 results.setdefault(role, []).append(run_result)
         return results

-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_dummy_compute(self):
-        res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
-        self.assertFalse(res.is_failed())
-        for return_value in res.return_values.values():
-            self.assertIsInstance(return_value, torch.Tensor)
-            self.assertEqual((100, 100), return_value.shape)
-
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
@@ -18,7 +18,6 @@ from itertools import product
 from typing import Dict, List
 from unittest import mock

-import torch
 import torch.multiprocessing as mp
 from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
 from torch.distributed.elastic.multiprocessing.api import (
@@ -145,13 +144,6 @@ def echo_large(size: int) -> Dict[int, str]:
     return out


-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
-
-
 def redirects() -> List[Std]:
     return [
         Std.NONE,
@@ -213,7 +205,6 @@ class StartProcessesTest(unittest.TestCase):

         for stdout_redir, stderr_redir in redirs:
             queue = multiprocessing.SimpleQueue()
-            worker_finished_event_mock = mock.Mock()
             _wrap(
                 local_rank=0,
                 fn=echo1,
@@ -222,14 +213,12 @@ class StartProcessesTest(unittest.TestCase):
                 stdout_redirects={0: stdout_redir},
                 stderr_redirects={0: stderr_redir},
                 ret_vals={0: queue},
-                queue_finished_reading_event=worker_finished_event_mock,
             )
             self.assertEqual("hello_0", queue.get())
             if stdout_redir:
                 self.assert_in_file(["hello stdout from 0"], stdout_log)
             if stderr_redir:
                 self.assert_in_file(["hello stderr from 0"], stderr_log)
-            worker_finished_event_mock.wait.assert_called_once()

     def test_invalid_log_dir(self):
         with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
@@ -350,26 +339,6 @@ class StartProcessesTest(unittest.TestCase):
                         [f"hello stderr from {i}"], results.stderrs[i]
                     )

-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_function_with_tensor(self):
-        for start_method in self._start_methods:
-            pc = start_processes(
-                name="dummy_compute",
-                entrypoint=dummy_compute,
-                args={},
-                envs={},
-                log_dir=self.log_dir(),
-                start_method=start_method,
-            )
-
-            results = pc.wait()
-            self.assert_pids_noexist(pc.pids())
-            for return_value in results.return_values.values():
-                self.assertIsInstance(return_value, torch.Tensor)
-                self.assertEqual((100, 100), return_value.shape)
-
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
@@ -271,7 +271,6 @@ def _wrap(
     stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
     stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
     ret_vals: Dict[int, mp.SimpleQueue],
-    queue_finished_reading_event: mp.Event,
 ) -> None:
     # get the per-rank params up front so we fail fast if no mapping is found
     args_ = args[local_rank]
@@ -290,7 +289,6 @@ def _wrap(
     with stdout_cm, stderr_cm:
         ret = record(fn)(*args_)
         ret_val_.put(ret)
-        queue_finished_reading_event.wait()


 class MultiprocessContext(PContext):
@@ -333,9 +331,6 @@ class MultiprocessContext(PContext):
         # see comments in ``join()`` for what this is
         self._return_values: Dict[int, Any] = {}
         self._pc: Optional[mp.ProcessContext] = None
-        # Note: set method should ONLY be invoked for the use case when all processes finished
-        # successfully. If any process died on event.wait() calling set() method will deadlock.
-        self._worker_finished_event = mp.get_context(self.start_method).Event()

     def _start(self):
         if self._pc:
@@ -352,7 +347,6 @@ class MultiprocessContext(PContext):
                 self.stdouts,
                 self.stderrs,
                 self._ret_vals,
-                self._worker_finished_event,
             ),
             nprocs=self.nprocs,
             join=False,
@@ -360,19 +354,15 @@
             start_method=self.start_method,
         )

-    def _is_done(self) -> bool:
-        return len(self._return_values) == self.nprocs
-
     def _poll(self) -> Optional[RunProcsResult]:
         assert self._pc is not None  # assertion for mypy type checker

         try:
-            # torch.mp.ProcessContext Throws an Exception if some/all of
-            # worker processes failed
+            # torch.mp.ProcessContext returns True if all the workers have
+            # successfully finished, False if some/all are still running
+            # and throws an Exception if some/all of them failed
             # timeout < 0 checks worker status and return immediately
-            # Join will never return success since we use mp.Event to wait
-            # for all processes to finish.
-            self._pc.join(-1)
+            done = self._pc.join(-1)

             # IMPORTANT: we use multiprocessing.Queue to carry worker return values
             # back to the parent, the worker process will wait before terminating
|
|
@ -386,12 +376,8 @@ class MultiprocessContext(PContext):
|
||||||
# save the return values temporarily into a member var
|
# save the return values temporarily into a member var
|
||||||
self._return_values[local_rank] = return_queue.get()
|
self._return_values[local_rank] = return_queue.get()
|
||||||
|
|
||||||
if self._is_done():
|
if done:
|
||||||
# we should ALWAYS have ALL the return values when all the processes are done
|
# we should ALWAYS have ALL the return values when all the processes are done
|
||||||
self._worker_finished_event.set()
|
|
||||||
# Wait untill all processes are finished. At this point workers finished executing
|
|
||||||
# user function
|
|
||||||
self._pc.join()
|
|
||||||
_validate_full_rank(
|
_validate_full_rank(
|
||||||
self._return_values, self.nprocs, "return_value queue"
|
self._return_values, self.nprocs, "return_value queue"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
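The restored _poll() leans on ProcessContext.join() instead of the Event: per the restored comment, a negative timeout checks worker status and returns immediately, yielding True only once every worker has exited cleanly and raising if any worker failed. A rough sketch of that polling idiom (the trivial worker function and sleep interval are illustrative):

import time
import torch.multiprocessing as mp

def trivial(rank):
    pass  # worker body; rank is passed in by mp.spawn

if __name__ == "__main__":
    pc = mp.spawn(trivial, nprocs=2, join=False)  # returns a ProcessContext
    while not pc.join(-1):   # non-blocking status check, as in _poll()
        time.sleep(0.1)      # parent stays responsive between checks
    # reaching here means every worker exited with exit code 0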