Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66182
closes https://github.com/pytorch/pytorch/issues/63174

Does a few things:
1. adds the hostname to the error report
2. moves the "root cause" section to the end (presumably, since the logs are being "tailed", we want the root cause to appear at the end)
3. moves redundant error info logging to debug
4. makes the border max 60 chars in length and left-justifies the header

NOTE: you HAVE TO annotate your main function with torch.distributed.elastic.multiprocessing.errors.record, otherwise no traceback is printed (this is because Python exception propagation does NOT work out of the box across IPC - hence the extra record annotation).

Test Plan: Sample

```
============================================================
run_script_path FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time : 2021-10-05_17:37:22
  host : devvm4955.prn0.facebook.com
  rank : 0 (local_rank: 0)
  exitcode : 1 (pid: 3296201)
  error_file: /home/kiuk/tmp/elastic/none_3_lsytqe/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/tmp/jetter.xr3_x6qq/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 372, in wrapper
      return f(*args, **kwargs)
    File "main.py", line 28, in main
      raise RuntimeError(args.throws)
  RuntimeError: foobar
============================================================
```

Reviewed By: cbalioglu, aivanou

Differential Revision: D31416492

fbshipit-source-id: 0aeaf6e634e23ce0ea7f6a03b12c8a9ac57246e9
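As a reference for the NOTE above, a minimal sketch of what the annotation looks like (this `main.py` is illustrative only and is not part of this change):

```
# Hypothetical trainer entrypoint. The @record decorator writes an uncaught
# exception (with its traceback) to the error file that the report above is
# rendered from; without it the traceback is lost across the IPC boundary.
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main() -> None:
    raise RuntimeError("foobar")  # gets recorded with a full traceback


if __name__ == "__main__":
    main()
```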
861 lines · 32 KiB · Python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import multiprocessing
import os
import shutil
import signal
import sys
import tempfile
import time
import unittest
from itertools import product
from typing import Callable, Dict, List, Union
from unittest import mock
from unittest.mock import patch

import torch
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
from torch.distributed.elastic.multiprocessing.api import (
    MultiprocessContext,
    RunProcsResult,
    SignalException,
    Std,
    _validate_full_rank,
    _wrap,
    to_map,
)
from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error
from torch.testing._internal.common_utils import (
    IS_IN_CI,
    IS_MACOS,
    IS_WINDOWS,
    NO_MULTIPROCESSING_SPAWN,
    TEST_WITH_ASAN,
    TEST_WITH_DEV_DBG_ASAN,
    TEST_WITH_TSAN,
    run_tests,
    sandcastle_skip_if,
)


class RunProcResultsTest(unittest.TestCase):
    def setUp(self):
        self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def test_is_failed(self):
        pr_success = RunProcsResult(return_values={0: "a", 1: "b"})
        self.assertFalse(pr_success.is_failed())

        fail0 = ProcessFailure(
            local_rank=0, pid=998, exitcode=1, error_file="ignored.json"
        )
        pr_fail = RunProcsResult(failures={0: fail0})
        self.assertTrue(pr_fail.is_failed())

    @patch("torch.distributed.elastic.multiprocessing.errors.log")
    def test_get_failures(self, log_mock):
        with mock.patch("time.time", side_effect=[3, 2, 1]):
            error_file0 = os.path.join(self.test_dir, "error0.json")
            error_file1 = os.path.join(self.test_dir, "error1.json")
            _write_error(RuntimeError("error 0"), error_file0)
            _write_error(RuntimeError("error 1"), error_file1)

            fail0 = ProcessFailure(
                local_rank=0, pid=997, exitcode=1, error_file=error_file0
            )
            fail1 = ProcessFailure(
                local_rank=1, pid=998, exitcode=3, error_file=error_file1
            )
            fail2 = ProcessFailure(
                local_rank=2, pid=999, exitcode=15, error_file="no_exist.json"
            )

            self.assertEqual(3, fail0.timestamp)
            self.assertEqual(2, fail1.timestamp)
            self.assertEqual(1, fail2.timestamp)


class StdTest(unittest.TestCase):
    def test_from_value(self):
        self.assertEqual(Std.NONE, Std.from_str("0"))
        self.assertEqual(Std.OUT, Std.from_str("1"))
        self.assertEqual(Std.ERR, Std.from_str("2"))
        self.assertEqual(Std.ALL, Std.from_str("3"))

    def test_from_value_map(self):
        self.assertEqual({0: Std.OUT}, Std.from_str("0:1"))
        self.assertEqual({0: Std.OUT, 1: Std.OUT}, Std.from_str("0:1,1:1"))

    def test_from_str_bad_input(self):
        bad_inputs = ["0:1,", "11", "0:1,1", "1,0:1"]
        for bad in bad_inputs:
            with self.subTest(bad=bad):
                with self.assertRaises(ValueError):
                    Std.from_str(bad)


def echo0(msg: str) -> None:
    """
    void function
    """
    print(msg)


def echo1(msg: str, exitcode: int = 0) -> str:
    """
    returns ``msg`` or exits with the given exitcode (if nonzero)
    """

    rank = int(os.environ["RANK"])
    if exitcode != 0:
        print(f"exit {exitcode} from {rank}", file=sys.stderr)
        sys.exit(exitcode)
    else:
        print(f"{msg} stdout from {rank}")
        print(f"{msg} stderr from {rank}", file=sys.stderr)
        return f"{msg}_{rank}"


def echo2(msg: str, fail: bool = False) -> str:
    """
    returns ``msg`` or raises a RuntimeError if ``fail`` is set
    """
    if fail:
        raise RuntimeError(msg)
    return msg


def echo_large(size: int) -> Dict[int, str]:
    """
    returns a large output ({0: "test0", 1: "test1", ..., (size-1): f"test{size-1}"})
    """
    out = {}
    for idx in range(0, size):
        out[idx] = f"test{idx}"
    return out


def echo3(msg: str, fail: bool = False) -> str:
    """
    returns ``msg`` or induces a SIGSEGV if ``fail`` is set
    """
    if fail:
        ctypes.string_at(0)
    return msg


def dummy_compute() -> torch.Tensor:
    """
    returns a predefined size random Tensor
    """
    return torch.rand(100, 100)


def redirects_oss_test() -> List[Std]:
    return [
        Std.NONE,
    ]


def redirects_all() -> List[Std]:
    return [
        Std.NONE,
        Std.OUT,
        Std.ERR,
        Std.ALL,
    ]


def bin(name: str):
    dir = os.path.dirname(__file__)
    return os.path.join(dir, "bin", name)


def wait_fn(wait_time: int = 300) -> None:
    time.sleep(wait_time)
    print("Finished waiting")


def start_processes_zombie_test(
    idx: int,
    entrypoint: Union[str, Callable],
    mp_queue: mp.Queue,
    log_dir: str,
    nproc: int = 2,
) -> None:
    """
    Starts processes
    """

    args = {}
    envs = {}
    for idx in range(nproc):
        args[idx] = ()
        envs[idx] = {}

    pc = start_processes(
        name="zombie_test",
        entrypoint=entrypoint,
        args=args,
        envs=envs,
        log_dir=log_dir,
        redirects=Std.NONE,
    )
    my_pid = os.getpid()
    mp_queue.put(my_pid)
    for child_pid in pc.pids().values():
        mp_queue.put(child_pid)

    try:
        pc.wait(period=1, timeout=300)
    except SignalException as e:
        pc.close(e.sigval)


# tests incompatible with tsan or asan
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

    class StartProcessesTest(unittest.TestCase):
        def setUp(self):
            self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
            self._start_methods = ["spawn"]

        def tearDown(self):
            shutil.rmtree(self.test_dir)

        def log_dir(self):
            return tempfile.mkdtemp(dir=self.test_dir)

        def assert_in_file(self, expected: List[str], filename: str) -> None:
            expected = [f"{line.rstrip()}\n" for line in expected]
            with open(filename, "r") as fp:
                actual = fp.readlines()
                for line in expected:
                    self.assertIn(line, actual)

        def assert_pids_noexist(self, pids: Dict[int, int]):
            for local_rank, pid in pids.items():
                with self.assertRaises(
                    OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
                ):
                    # signal 0 delivers nothing; it only probes whether the
                    # pid still exists (os.kill raises OSError if it does not)
                    os.kill(pid, 0)

        def test_to_map(self):
            local_world_size = 2
            self.assertEqual(
                {0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size)
            )
            self.assertEqual(
                {0: Std.NONE, 1: Std.OUT}, to_map({1: Std.OUT}, local_world_size)
            )
            self.assertEqual(
                {0: Std.ERR, 1: Std.OUT},
                to_map({0: Std.ERR, 1: Std.OUT}, local_world_size),
            )

        def test_invalid_log_dir(self):
            with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
                cases = {
                    "does_not_exist": FileNotFoundError,
                    not_a_dir.name: NotADirectoryError,
                    # test_dir is not empty since we touched not_a_dir file
                    self.test_dir: RuntimeError,
                }

                for (log_dir, expected_error) in cases.items():
                    with self.subTest(log_dir=log_dir, expected_error=expected_error):
                        with self.assertRaises(expected_error):
                            start_processes(
                                name="echo",
                                entrypoint=echo1,
                                args={0: ("hello",)},
                                envs={0: {"RANK": "0"}},
                                log_dir=log_dir,
                            )

        def test_args_env_len_mismatch(self):
            cases = [
                # 1 x args; 2 x envs
                {
                    "args": {0: ("hello",)},
                    "envs": {0: {"RANK": "0"}, 1: {"RANK": "1"}},
                },
                # 2 x args; 1 x envs
                {
                    "args": {0: ("hello",), 1: ("world",)},
                    "envs": {0: {"RANK": "0"}},
                },
            ]

            for kwds in cases:
                args = kwds["args"]
                envs = kwds["envs"]
                with self.subTest(args=args, envs=envs):
                    with self.assertRaises(RuntimeError):
                        start_processes(
                            name="echo",
                            entrypoint=echo1,
                            args=args,
                            envs=envs,
                            log_dir=self.log_dir(),
                        )

        def test_pcontext_wait(self):
            pc = start_processes(
                name="sleep",
                entrypoint=time.sleep,
                args={0: (1,)},
                envs={0: {}},
                log_dir=self.log_dir(),
                start_method="spawn",
            )

            self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
            self.assertIsNotNone(pc.wait(period=0.1))
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())

        def test_multiprocess_context_close(self):
            pc = start_processes(
                name="sleep",
                entrypoint=time.sleep,
                args={0: (1,)},
                envs={0: {}},
                log_dir=self.log_dir(),
                start_method="spawn",
            )

            pids = pc.pids()
            pc.close()
            self.assert_pids_noexist(pids)
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())

        def test_subprocess_context_close(self):
            pc = start_processes(
                name="sleep",
                entrypoint=bin("zombie_test.py"),
                args={0: (1,)},
                envs={0: {}},
                log_dir=self.log_dir(),
            )

            pids = pc.pids()
            pc.close()
            self.assert_pids_noexist(pids)

        def test_function_with_tensor(self):
            for start_method in self._start_methods:
                pc = start_processes(
                    name="dummy_compute",
                    entrypoint=dummy_compute,
                    args={},
                    envs={},
                    log_dir=self.log_dir(),
                    start_method=start_method,
                )

                results = pc.wait()
                self.assert_pids_noexist(pc.pids())
                for return_value in results.return_values.values():
                    self.assertIsInstance(return_value, torch.Tensor)
                    self.assertEqual((100, 100), return_value.shape)

        def test_void_function(self):
            for start_method in self._start_methods:
                with self.subTest(start_method=start_method):
                    pc = start_processes(
                        name="echo",
                        entrypoint=echo0,
                        args={0: ("hello",), 1: ("world",)},
                        envs={0: {}, 1: {}},
                        log_dir=self.log_dir(),
                        start_method=start_method,
                    )

                    results = pc.wait(period=0.1)
                    self.assertEqual({0: None, 1: None}, results.return_values)

        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan")
        def test_function_large_ret_val(self):
            # Python's multiprocessing.Queue is implemented on top of pipes,
            # so if a single object is larger than the pipe buffer the writer
            # process blocks until the reader process starts draining the pipe.
            # This test makes the worker fn return a huge output, around ~10 MB.

            size = 200000
            for start_method in self._start_methods:
                with self.subTest(start_method=start_method):
                    pc = start_processes(
                        name="echo",
                        entrypoint=echo_large,
                        args={0: (size,), 1: (size,), 2: (size,), 3: (size,)},
                        envs={0: {}, 1: {}, 2: {}, 3: {}},
                        log_dir=self.log_dir(),
                        start_method=start_method,
                    )

                    results = pc.wait(period=0.1)
                    for i in range(pc.nprocs):
                        self.assertEqual(size, len(results.return_values[i]))

        def test_function_raise(self):
            """
            run 2x copies of echo2, raise an exception on the first
            """
            RAISE = True

            for start_method in self._start_methods:
                with self.subTest(start_method=start_method):
                    log_dir = self.log_dir()
                    pc = start_processes(
                        name="echo",
                        entrypoint=echo2,
                        args={0: ("hello", RAISE), 1: ("world",)},
                        envs={0: {}, 1: {}},
                        log_dir=log_dir,
                        start_method=start_method,
                    )

                    results = pc.wait(period=0.1)

                    self.assert_pids_noexist(pc.pids())
                    self.assertEqual(1, len(results.failures))
                    self.assertFalse(results.return_values)

                    failure = results.failures[0]
                    error_file = failure.error_file
                    error_file_data = failure.error_file_data

                    self.assertEqual(1, failure.exitcode)
                    self.assertEqual("<N/A>", failure.signal_name())
                    self.assertEqual(pc.pids()[0], failure.pid)
                    self.assertEqual(
                        os.path.join(log_dir, "0", "error.json"), error_file
                    )
                    self.assertEqual(
                        int(error_file_data["message"]["extraInfo"]["timestamp"]),
                        int(failure.timestamp),
                    )
                    self.assertTrue(pc._stderr_tail.stopped())
                    self.assertTrue(pc._stdout_tail.stopped())

        ########################################
        # start_processes as binary tests
        ########################################

        def test_binary_exit(self):
            FAIL = 138
            pc = start_processes(
                name="echo",
                entrypoint=bin("echo1.py"),
                args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=self.log_dir(),
                redirects={0: Std.ALL},
            )

            results = pc.wait(period=0.1)

            self.assertTrue(results.is_failed())
            self.assertEqual(1, len(results.failures))

            failure = results.failures[0]
            self.assertEqual(138, failure.exitcode)
            self.assertEqual("<N/A>", failure.signal_name())
            self.assertEqual("<NONE>", failure.error_file_data["message"])
            self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
            self.assert_in_file([], results.stdouts[0])
            self.assertFalse(results.stderrs[1])
            self.assertFalse(results.stdouts[1])
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())

        def test_binary_raises(self):
            pc = start_processes(
                name="echo",
                entrypoint=bin("echo2.py"),
                args={0: ("--raises", "true", "foo"), 1: ("bar",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=self.log_dir(),
            )

            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertTrue(results.is_failed())
            self.assertEqual(1, len(results.failures))

            failure = results.failures[0]
            self.assertEqual(1, failure.exitcode)
            self.assertEqual("<NONE>", failure.error_file_data["message"])
            self.assertEqual("<N/A>", failure.signal_name())

        def test_binary_incorrect_entrypoint(self):
            with self.assertRaises(FileNotFoundError):
                start_processes(
                    name="echo",
                    entrypoint="does_not_exist.py",
                    # args values are tuples; ("foo") would just be a str
                    args={0: ("foo",), 1: ("bar",)},
                    envs={0: {}, 1: {}},
                    log_dir=self.log_dir(),
                )

        def test_validate_full_rank(self):
            with self.assertRaises(RuntimeError):
                _validate_full_rank({}, 10, "")

        @sandcastle_skip_if(
            NO_MULTIPROCESSING_SPAWN,
            "Disabled for environments that \
            don't support multiprocessing with spawn start method",
        )
        def test_multiprocessing_context_poll_raises_exception(self):
            mp_context = MultiprocessContext(
                name="test_mp",
                entrypoint=echo0,
                args={0: (0, 1)},
                envs={},
                stdouts={0: {}},
                stderrs={0: {}},
                tee_stdouts={0: "tee_stdout"},
                tee_stderrs={0: "tee_stderr"},
                error_files={0: "test_file"},
                start_method="spawn",
            )
            mp_context._pc = mock.Mock()
            # Using mock since we cannot just set exitcode on process
            mock_process = mock.Mock()
            mock_process.exitcode = -1
            mp_context._pc.processes = [mock_process]
            e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
            mp_context._pc.join.side_effect = e
            with mock.patch.object(mp_context, "close"):
                run_result = mp_context._poll()
                self.assertEqual(1, len(run_result.failures))
                failure = run_result.failures[0]
                self.assertEqual(
                    "Signal 1 (SIGHUP) received by PID 123", failure.message
                )


# tests incompatible with tsan or asan; the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

    class StartProcessesListTest(StartProcessesTest):
        ########################################
        # start_processes as binary tests
        ########################################
        def test_function(self):
            for start_method, redirs in product(
                self._start_methods, redirects_oss_test()
            ):
                with self.subTest(start_method=start_method, redirs=redirs):
                    pc = start_processes(
                        name="echo",
                        entrypoint=echo1,
                        args={0: ("hello",), 1: ("hello",)},
                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                        log_dir=self.log_dir(),
                        start_method=start_method,
                        redirects=redirs,
                    )

                    results = pc.wait(period=0.1)
                    nprocs = pc.nprocs

                    self.assert_pids_noexist(pc.pids())
                    self.assertEqual(
                        {i: f"hello_{i}" for i in range(nprocs)}, results.return_values
                    )

                    for i in range(nprocs):
                        if redirs & Std.OUT != Std.OUT:
                            self.assertFalse(results.stdouts[i])
                        if redirs & Std.ERR != Std.ERR:
                            self.assertFalse(results.stderrs[i])
                        if redirs & Std.OUT == Std.OUT:
                            self.assert_in_file(
                                [f"hello stdout from {i}"], results.stdouts[i]
                            )
                        if redirs & Std.ERR == Std.ERR:
                            self.assert_in_file(
                                [f"hello stderr from {i}"], results.stderrs[i]
                            )

        def test_binary(self):
            for redirs in redirects_oss_test():
                with self.subTest(redirs=redirs):
                    pc = start_processes(
                        name="echo",
                        entrypoint=bin("echo1.py"),
                        args={0: ("hello",), 1: ("hello",)},
                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                        log_dir=self.log_dir(),
                        redirects=redirs,
                    )

                    results = pc.wait(period=0.1)

                    self.assert_pids_noexist(pc.pids())
                    # currently binaries return {rank: None}
                    self.assertEqual(2, len(results.return_values))
                    self.assertFalse(results.is_failed())

                    nprocs = pc.nprocs
                    for i in range(nprocs):
                        if redirs & Std.OUT != Std.OUT:
                            self.assertFalse(results.stdouts[i])
                        if redirs & Std.ERR != Std.ERR:
                            self.assertFalse(results.stderrs[i])
                        if redirs & Std.OUT == Std.OUT:
                            self.assert_in_file(
                                [f"hello stdout from {i}"], results.stdouts[i]
                            )
                        if redirs & Std.ERR == Std.ERR:
                            self.assert_in_file(
                                [f"hello stderr from {i}"], results.stderrs[i]
                            )

        def test_binary_redirect_and_tee(self):
            pc = start_processes(
                name="trainer",
                entrypoint=bin("echo1.py"),
                args={0: ("hello",), 1: ("world",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=self.log_dir(),
                start_method="spawn",
                redirects={0: Std.ERR, 1: Std.NONE},
                tee={0: Std.OUT, 1: Std.ERR},
            )

            result = pc.wait()

            self.assertFalse(result.is_failed())
            self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
            self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
            self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
            self.assertFalse(pc.stdouts[1])
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())


# tests incompatible with tsan or asan; the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):

    class StartProcessesNotCITest(StartProcessesTest):
        def test_wrap_bad(self):
            none = ""
            stdout_log = os.path.join(self.test_dir, "stdout.log")
            stderr_log = os.path.join(self.test_dir, "stderr.log")
            redirs = [
                (none, none),
                (none, stderr_log),
                (stdout_log, none),
                (stdout_log, stderr_log),
            ]

            for stdout_redir, stderr_redir in redirs:
                queue = multiprocessing.SimpleQueue()
                worker_finished_event_mock = mock.Mock()
                _wrap(
                    local_rank=0,
                    fn=echo1,
                    args={0: ("hello",)},
                    envs={0: {"RANK": "0"}},
                    stdout_redirects={0: stdout_redir},
                    stderr_redirects={0: stderr_redir},
                    ret_vals={0: queue},
                    queue_finished_reading_event=worker_finished_event_mock,
                )
                self.assertEqual("hello_0", queue.get())
                if stdout_redir:
                    self.assert_in_file(["hello stdout from 0"], stdout_log)
                if stderr_redir:
                    self.assert_in_file(["hello stderr from 0"], stderr_log)
                worker_finished_event_mock.wait.assert_called_once()

        def test_binary_signal(self):
            pc = start_processes(
                name="echo",
                entrypoint=bin("echo3.py"),
                args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                log_dir=self.log_dir(),
            )

            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertTrue(results.is_failed())
            self.assertEqual(1, len(results.failures))

            failure = results.failures[0]
            self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
            if TEST_WITH_ASAN or TEST_WITH_TSAN:
                # ASAN/TSAN exit code is 1.
                self.assertEqual("<N/A>", failure.signal_name())
            else:
                self.assertEqual("SIGSEGV", failure.signal_name())
            self.assertEqual("<NONE>", failure.error_file_data["message"])

        def test_function_redirect_and_tee(self):
            for start_method in self._start_methods:
                with self.subTest(start_method=start_method):
                    log_dir = self.log_dir()
                    pc = start_processes(
                        name="trainer",
                        entrypoint=echo1,
                        args={0: ("hello",), 1: ("world",)},
                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                        log_dir=log_dir,
                        start_method="spawn",
                        redirects={0: Std.ERR, 1: Std.NONE},
                        tee={0: Std.OUT, 1: Std.ERR},
                    )

                    result = pc.wait()

                    self.assertFalse(result.is_failed())
                    self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
                    self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
                    self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
                    self.assertFalse(pc.stdouts[1])
                    self.assertTrue(pc._stderr_tail.stopped())
                    self.assertTrue(pc._stdout_tail.stopped())

        def test_function(self):
            for start_method, redirs in product(self._start_methods, redirects_all()):
                with self.subTest(start_method=start_method, redirs=redirs):
                    pc = start_processes(
                        name="echo",
                        entrypoint=echo1,
                        args={0: ("hello",), 1: ("hello",)},
                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                        log_dir=self.log_dir(),
                        start_method=start_method,
                        redirects=redirs,
                    )

                    results = pc.wait(period=0.1)
                    nprocs = pc.nprocs

                    self.assert_pids_noexist(pc.pids())
                    self.assertEqual(
                        {i: f"hello_{i}" for i in range(nprocs)}, results.return_values
                    )

                    for i in range(nprocs):
                        if redirs & Std.OUT != Std.OUT:
                            self.assertFalse(results.stdouts[i])
                        if redirs & Std.ERR != Std.ERR:
                            self.assertFalse(results.stderrs[i])
                        if redirs & Std.OUT == Std.OUT:
                            self.assert_in_file(
                                [f"hello stdout from {i}"], results.stdouts[i]
                            )
                        if redirs & Std.ERR == Std.ERR:
                            self.assert_in_file(
                                [f"hello stderr from {i}"], results.stderrs[i]
                            )

        def test_function_exit(self):
            """
            run 2x copies of echo1; fail (exit) the first

            functions that exit from python do not generate an error file
            (even if they are decorated with @record)
            """

            FAIL = 138
            for start_method in self._start_methods:
                with self.subTest(start_method=start_method):
                    log_dir = self.log_dir()
                    pc = start_processes(
                        name="echo",
                        entrypoint=echo1,
                        args={0: ("hello", FAIL), 1: ("hello",)},
                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                        log_dir=log_dir,
                        start_method=start_method,
                        redirects={0: Std.ERR},
                    )

                    results = pc.wait(period=0.1)

                    self.assert_pids_noexist(pc.pids())
                    self.assertTrue(results.is_failed())
                    self.assertEqual(1, len(results.failures))
                    self.assertFalse(results.return_values)

                    failure = results.failures[0]
                    error_file = failure.error_file

                    self.assertEqual(FAIL, failure.exitcode)
                    self.assertEqual("<N/A>", failure.signal_name())
                    self.assertEqual(pc.pids()[0], failure.pid)
                    self.assertEqual("<N/A>", error_file)
                    self.assertEqual(
                        "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html",
                        failure.message,
                    )
                    self.assertLessEqual(failure.timestamp, int(time.time()))

                    self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
                    self.assertFalse(results.stdouts[0])
                    self.assertFalse(results.stderrs[1])
                    self.assertFalse(results.stdouts[1])
                    self.assertTrue(pc._stderr_tail.stopped())
                    self.assertTrue(pc._stdout_tail.stopped())

        def test_no_zombie_process_binary(self):
            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
            for s in signals:
                self._test_zombie_workflow(bin("zombie_test.py"), s)

        def test_no_zombie_process_function(self):
            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
            for s in signals:
                self._test_zombie_workflow(wait_fn, s)

        def _test_zombie_workflow(
            self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
        ) -> None:
            mp_queue = mp.get_context("spawn").Queue()
            child_nproc = 2
            ctx = mp.spawn(
                start_processes_zombie_test,
                nprocs=1,
                args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
                join=False,
            )
            total_processes = child_nproc + 1
            pids = []
            for _ in range(total_processes):
                pids.append(mp_queue.get(timeout=120))
            parent_pid = pids[0]
            child_pids = pids[1:]

            os.kill(parent_pid, signal.SIGTERM)
            # Wait to give time for signal handlers to finish work
            time.sleep(5)
            for child_pid in child_pids:
                # Killing the parent should kill all children; we expect each
                # call to os.kill to raise OSError
                with self.assertRaises(OSError):
                    os.kill(child_pid, 0)


if __name__ == "__main__":
    run_tests()