diff --git a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py
index caa6bc73ae2..f272180008c 100644
--- a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py
+++ b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py
@@ -71,13 +71,6 @@ def _sad_function():
     raise RuntimeError("sad because i throw")
 
 
-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
-
-
 def _fatal_signal_function(expected_error_index: int, sig: int):
     rank = int(os.environ["RANK"])
     if rank == expected_error_index:
@@ -313,16 +306,6 @@ class LocalElasticAgentTest(unittest.TestCase):
             results.setdefault(role, []).append(run_result)
         return results
 
-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_dummy_compute(self):
-        res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
-        self.assertFalse(res.is_failed())
-        for return_value in res.return_values.values():
-            self.assertIsInstance(return_value, torch.Tensor)
-            self.assertEqual((100, 100), return_value.shape)
-
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py
index 731d9e2ce76..ab65b98e1d0 100644
--- a/test/distributed/elastic/multiprocessing/api_test.py
+++ b/test/distributed/elastic/multiprocessing/api_test.py
@@ -18,7 +18,6 @@ from itertools import product
 from typing import Dict, List
 from unittest import mock
 
-import torch
 import torch.multiprocessing as mp
 from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
 from torch.distributed.elastic.multiprocessing.api import (
@@ -145,13 +144,6 @@ def echo_large(size: int) -> Dict[int, str]:
     return out
 
 
-def dummy_compute() -> torch.Tensor:
-    """
-    returns a predefined size random Tensor
-    """
-    return torch.rand(100, 100)
-
-
 def redirects() -> List[Std]:
     return [
         Std.NONE,
@@ -213,7 +205,6 @@ class StartProcessesTest(unittest.TestCase):
 
         for stdout_redir, stderr_redir in redirs:
             queue = multiprocessing.SimpleQueue()
-            worker_finished_event_mock = mock.Mock()
             _wrap(
                 local_rank=0,
                 fn=echo1,
@@ -222,14 +213,12 @@ class StartProcessesTest(unittest.TestCase):
                 stdout_redirects={0: stdout_redir},
                 stderr_redirects={0: stderr_redir},
                 ret_vals={0: queue},
-                queue_finished_reading_event=worker_finished_event_mock,
             )
             self.assertEqual("hello_0", queue.get())
             if stdout_redir:
                 self.assert_in_file(["hello stdout from 0"], stdout_log)
             if stderr_redir:
                 self.assert_in_file(["hello stderr from 0"], stderr_log)
-            worker_finished_event_mock.wait.assert_called_once()
 
     def test_invalid_log_dir(self):
         with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
@@ -350,26 +339,6 @@ class StartProcessesTest(unittest.TestCase):
                     [f"hello stderr from {i}"], results.stderrs[i]
                 )
 
-    @unittest.skipIf(
-        TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
-    )
-    def test_function_with_tensor(self):
-        for start_method in self._start_methods:
-            pc = start_processes(
-                name="dummy_compute",
-                entrypoint=dummy_compute,
-                args={},
-                envs={},
-                log_dir=self.log_dir(),
-                start_method=start_method,
-            )
-
-            results = pc.wait()
-            self.assert_pids_noexist(pc.pids())
-            for return_value in results.return_values.values():
-                self.assertIsInstance(return_value, torch.Tensor)
-                self.assertEqual((100, 100), return_value.shape)
-
     @unittest.skipIf(
         TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
     )
diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py
index 9205644ff01..196e6bb3bcb 100644
--- a/torch/distributed/elastic/multiprocessing/api.py
+++ b/torch/distributed/elastic/multiprocessing/api.py
@@ -271,7 +271,6 @@ def _wrap(
     stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
     stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
     ret_vals: Dict[int, mp.SimpleQueue],
-    queue_finished_reading_event: mp.Event,
 ) -> None:
     # get the per-rank params up front so we fail fast if no mapping is found
     args_ = args[local_rank]
@@ -290,7 +289,6 @@ def _wrap(
     with stdout_cm, stderr_cm:
         ret = record(fn)(*args_)
     ret_val_.put(ret)
-    queue_finished_reading_event.wait()
 
 
 class MultiprocessContext(PContext):
@@ -333,9 +331,6 @@ class MultiprocessContext(PContext):
         # see comments in ``join()`` for what this is
         self._return_values: Dict[int, Any] = {}
         self._pc: Optional[mp.ProcessContext] = None
-        # Note: set method should ONLY be invoked for the use case when all processes finished
-        # successfully. If any process died on event.wait() calling set() method will deadlock.
-        self._worker_finished_event = mp.get_context(self.start_method).Event()
 
     def _start(self):
         if self._pc:
@@ -352,7 +347,6 @@ class MultiprocessContext(PContext):
                 self.stdouts,
                 self.stderrs,
                 self._ret_vals,
-                self._worker_finished_event,
             ),
             nprocs=self.nprocs,
             join=False,
@@ -360,19 +354,15 @@ class MultiprocessContext(PContext):
             start_method=self.start_method,
         )
 
-    def _is_done(self) -> bool:
-        return len(self._return_values) == self.nprocs
-
     def _poll(self) -> Optional[RunProcsResult]:
         assert self._pc is not None  # assertion for mypy type checker
 
         try:
-            # torch.mp.ProcessContext Throws an Exception if some/all of
-            # worker processes failed
+            # torch.mp.ProcessContext returns True if all the workers have
+            # successfully finished, False if some/all are still running
+            # and throws an Exception if some/all of them failed
             # timeout < 0 checks worker status and return immediately
-            # Join will never return success since we use mp.Event to wait
-            # for all processes to finish.
-            self._pc.join(-1)
+            done = self._pc.join(-1)
 
             # IMPORTANT: we use multiprocessing.Queue to carry worker return values
             # back to the parent, the worker process will wait before terminating
@@ -386,12 +376,8 @@ class MultiprocessContext(PContext):
                     # save the return values temporarily into a member var
                     self._return_values[local_rank] = return_queue.get()
 
-            if self._is_done():
+            if done:
                 # we should ALWAYS have ALL the return values when all the processes are done
-                self._worker_finished_event.set()
-                # Wait untill all processes are finished. At this point workers finished executing
-                # user function
-                self._pc.join()
                 _validate_full_rank(
                     self._return_values, self.nprocs, "return_value queue"
                 )