diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index 7e6f8142839..d8483f115f5 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -12,65 +12,65 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOW import torch.multiprocessing as mp -def test_success_func(i): +def _test_success_func(i): pass -def test_success_single_arg_func(i, arg): +def _test_success_single_arg_func(i, arg): if arg: arg.put(i) -def test_exception_single_func(i, arg): +def _test_exception_single_func(i, arg): if i == arg: raise ValueError("legitimate exception from process %d" % i) time.sleep(1.0) -def test_exception_all_func(i): +def _test_exception_all_func(i): time.sleep(random.random() / 10) raise ValueError("legitimate exception from process %d" % i) -def test_terminate_signal_func(i): +def _test_terminate_signal_func(i): if i == 0: os.kill(os.getpid(), signal.SIGABRT) time.sleep(1.0) -def test_terminate_exit_func(i, arg): +def _test_terminate_exit_func(i, arg): if i == 0: sys.exit(arg) time.sleep(1.0) -def test_success_first_then_exception_func(i, arg): +def _test_success_first_then_exception_func(i, arg): if i == 0: return time.sleep(0.1) raise ValueError("legitimate exception") -def test_nested_child_body(i, ready_queue, nested_child_sleep): +def _test_nested_child_body(i, ready_queue, nested_child_sleep): ready_queue.put(None) time.sleep(nested_child_sleep) -def test_infinite_task(i): +def _test_infinite_task(i): while True: time.sleep(1) -def test_process_exit(idx): +def _test_process_exit(idx): sys.exit(12) -def test_nested(i, pids_queue, nested_child_sleep, start_method): +def _test_nested(i, pids_queue, nested_child_sleep, start_method): context = mp.get_context(start_method) nested_child_ready_queue = context.Queue() nprocs = 2 mp_context = mp.start_processes( - fn=test_nested_child_body, + fn=_test_nested_child_body, args=(nested_child_ready_queue, nested_child_sleep), nprocs=nprocs, join=False, @@ -91,10 +91,10 @@ class _TestMultiProcessing(object): start_method = None def test_success(self): - mp.start_processes(test_success_func, nprocs=2, start_method=self.start_method) + mp.start_processes(_test_success_func, nprocs=2, start_method=self.start_method) def test_success_non_blocking(self): - mp_context = mp.start_processes(test_success_func, nprocs=2, join=False, start_method=self.start_method) + mp_context = mp.start_processes(_test_success_func, nprocs=2, join=False, start_method=self.start_method) # After all processes (nproc=2) have joined it must return True mp_context.join(timeout=None) @@ -104,7 +104,7 @@ class _TestMultiProcessing(object): def test_first_argument_index(self): context = mp.get_context(self.start_method) queue = context.SimpleQueue() - mp.start_processes(test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method) + mp.start_processes(_test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method) self.assertEqual([0, 1], sorted([queue.get(), queue.get()])) def test_exception_single(self): @@ -114,14 +114,14 @@ class _TestMultiProcessing(object): Exception, "\nValueError: legitimate exception from process %d$" % i, ): - mp.start_processes(test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method) + mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method) def test_exception_all(self): with self.assertRaisesRegex( Exception, "\nValueError: legitimate exception from process (0|1)$", ): - mp.start_processes(test_exception_all_func, nprocs=2, start_method=self.start_method) + mp.start_processes(_test_exception_all_func, nprocs=2, start_method=self.start_method) def test_terminate_signal(self): # SIGABRT is aliased with SIGIOT @@ -136,7 +136,7 @@ class _TestMultiProcessing(object): message = "process 0 terminated with exit code 22" with self.assertRaisesRegex(Exception, message): - mp.start_processes(test_terminate_signal_func, nprocs=2, start_method=self.start_method) + mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method) def test_terminate_exit(self): exitcode = 123 @@ -144,7 +144,7 @@ class _TestMultiProcessing(object): Exception, "process 0 terminated with exit code %d" % exitcode, ): - mp.start_processes(test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method) + mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method) def test_success_first_then_exception(self): exitcode = 123 @@ -152,18 +152,18 @@ class _TestMultiProcessing(object): Exception, "ValueError: legitimate exception", ): - mp.start_processes(test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method) + mp.start_processes(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method) @unittest.skipIf( sys.platform != "linux", "Only runs on Linux; requires prctl(2)", ) - def test_nested(self): + def _test_nested(self): context = mp.get_context(self.start_method) pids_queue = context.Queue() nested_child_sleep = 20.0 mp_context = mp.start_processes( - fn=test_nested, + fn=_test_nested, args=(pids_queue, nested_child_sleep, self.start_method), nprocs=1, join=False, @@ -197,18 +197,18 @@ class SpawnTest(TestCase, _TestMultiProcessing): def test_exception_raises(self): with self.assertRaises(mp.ProcessRaisedException): - mp.spawn(test_success_first_then_exception_func, args=(), nprocs=1) + mp.spawn(_test_success_first_then_exception_func, args=(), nprocs=1) def test_signal_raises(self): - context = mp.spawn(test_infinite_task, args=(), nprocs=1, join=False) + context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False) for pid in context.pids(): os.kill(pid, signal.SIGTERM) with self.assertRaises(mp.ProcessExitedException): context.join() - def test_process_exited(self): + def _test_process_exited(self): with self.assertRaises(mp.ProcessExitedException) as e: - mp.spawn(test_process_exit, args=(), nprocs=1) + mp.spawn(_test_process_exit, args=(), nprocs=1) self.assertEqual(12, e.exit_code)