[torchelastic] Improve process termination logic (#61602)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61602

This diff introduces signal handlers and a `SignalException` that is raised when the agent process receives SIGTERM or SIGINT.

When either of these signals is received, the termination handler raises the `SignalException`. The exception is then processed by the main agent loop, which invokes `shutdown(signum)` to propagate the received signal to the child processes. A default timeout of 30 seconds is introduced: if the child processes are not able to terminate gracefully within this timeout, the agent process kills them via SIGKILL.

Test Plan: unittests, sandcastle

Reviewed By: cbalioglu

Differential Revision: D29671783

fbshipit-source-id: 3dbca2125676dc18d417cc3e3bb0301fdd42737a
This commit is contained in:
Aliaksandr Ivanou 2021-07-23 10:58:39 -07:00 committed by Facebook GitHub Bot
parent e42360d56f
commit 0c55f1bdec
7 changed files with 295 additions and 53 deletions

View File

@ -7,6 +7,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import signal
import unittest import unittest
import uuid import uuid
from typing import Any, Dict from typing import Any, Dict
@ -22,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import (
_get_fq_hostname, _get_fq_hostname,
_RoleInstanceInfo, _RoleInstanceInfo,
) )
from torch.distributed.elastic.multiprocessing import SignalException
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
from torch.distributed.elastic.utils.distributed import get_free_port from torch.distributed.elastic.utils.distributed import get_free_port
@ -550,6 +552,20 @@ class SimpleElasticAgentTest(unittest.TestCase):
self.assertEqual(spec.role, actual_event.metadata["role"]) self.assertEqual(spec.role, actual_event.metadata["role"])
self.assertEqual(2, actual_event.metadata["agent_restarts"]) self.assertEqual(2, actual_event.metadata["agent_restarts"])
@patch("torch.distributed.elastic.agent.server.api.put_metric")
@patch.object(TestAgent, "_invoke_run")
def test_agent_process_signal_exception(self, invoke_run, put_metric_mock):
    """
    Checks that ``run()`` re-raises a ``SignalException`` coming out of
    ``_invoke_run`` and passes the signal it carries to ``_shutdown``.
    """
    agent = TestAgent(self._get_worker_spec(max_restarts=0))
    # Make the patched agent loop fail the way the termination handler would.
    invoke_run.side_effect = SignalException(
        "signal exception", sigval=signal.SIGTERM
    )
    with patch.object(agent, "_shutdown") as shutdown_mock:
        with self.assertRaises(SignalException):
            agent.run()
    # ``call_args`` is an ``(args, kwargs)`` pair; the signal is the first positional.
    positional_args = shutdown_mock.call_args[0]
    self.assertEqual(signal.SIGTERM, positional_args[0])
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -15,7 +15,7 @@ import tempfile
import time import time
import unittest import unittest
from itertools import product from itertools import product
from typing import Dict, List from typing import Dict, List, Union, Callable
from unittest import mock from unittest import mock
from unittest.mock import patch from unittest.mock import patch
@ -24,6 +24,7 @@ import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
from torch.distributed.elastic.multiprocessing.api import ( from torch.distributed.elastic.multiprocessing.api import (
MultiprocessContext, MultiprocessContext,
SignalException,
RunProcsResult, RunProcsResult,
Std, Std,
_validate_full_rank, _validate_full_rank,
@ -173,6 +174,52 @@ def redirects_all() -> List[Std]:
] ]
def bin(name: str) -> str:
    """
    Returns the path of the helper script ``name`` located in the ``bin``
    directory next to this module.

    Args:
        name: File name of the script, e.g. ``"echo1.py"``.
    """
    # The function name itself shadows the ``bin`` builtin but is kept because
    # the tests call it by this name; the local no longer shadows ``dir``.
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "bin", name)
def wait_fn(wait_time: int = 300) -> None:
    """
    Sleeps for ``wait_time`` seconds and then prints a completion message.

    Args:
        wait_time: Number of seconds to sleep, 300 by default.
    """
    time.sleep(wait_time)
    print("Finished waiting")
def start_processes_zombie_test(
    idx: int,
    entrypoint: Union[str, Callable],
    mp_queue: mp.Queue,
    log_dir: str,
    nproc: int = 2,
) -> None:
    """
    Starts ``nproc`` worker processes, reports the pids through ``mp_queue``
    and waits for the workers, closing them if a termination signal arrives.

    Args:
        idx: Index of this process as passed by the spawner (not used).
        entrypoint: Binary path or function that every worker runs.
        mp_queue: Queue that receives this process's pid first, followed by
            the pid of each worker.
        log_dir: Directory handed to ``start_processes`` for worker logs.
        nproc: Number of workers to start.
    """
    # Workers take no arguments and no extra environment variables. A separate
    # name is used for the rank so the ``idx`` parameter is not overwritten.
    args = {local_rank: () for local_rank in range(nproc)}
    envs = {local_rank: {} for local_rank in range(nproc)}
    pc = start_processes(
        name="zombie_test",
        entrypoint=entrypoint,
        args=args,
        envs=envs,
        log_dir=log_dir,
        redirects=Std.NONE,
    )
    # The pid of this process goes first so the reader can tell it apart from
    # the worker pids that follow.
    mp_queue.put(os.getpid())
    for child_pid in pc.pids().values():
        mp_queue.put(child_pid)

    try:
        pc.wait(period=1, timeout=300)
    except SignalException as e:
        # Forward the received signal to the workers so none of them outlives us.
        pc.close(e.sigval)
@unittest.skipIf( @unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS, TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS,
"tests incompatible with tsan or asan", "tests incompatible with tsan or asan",
@ -294,6 +341,19 @@ class StartProcessesTest(unittest.TestCase):
self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped())
def test_subprocess_context_close(self):
    """
    Starts a single binary worker, closes the context and checks that the
    worker pid no longer exists.
    """
    sleeper = start_processes(
        name="sleep",
        entrypoint=bin("zombie_test.py"),
        args={0: (1,)},
        envs={0: {}},
        log_dir=self.log_dir(),
    )
    # Collect the pids before closing so they can be checked afterwards.
    worker_pids = sleeper.pids()
    sleeper.close()
    self.assert_pids_noexist(worker_pids)
def test_function_with_tensor(self): def test_function_with_tensor(self):
for start_method in self._start_methods: for start_method in self._start_methods:
pc = start_processes( pc = start_processes(
@ -395,15 +455,11 @@ class StartProcessesTest(unittest.TestCase):
# start_processes as binary tests # start_processes as binary tests
######################################## ########################################
def bin(self, name: str):
dir = os.path.dirname(__file__)
return os.path.join(dir, "bin", name)
def test_binary_exit(self): def test_binary_exit(self):
FAIL = 138 FAIL = 138
pc = start_processes( pc = start_processes(
name="echo", name="echo",
entrypoint=self.bin("echo1.py"), entrypoint=bin("echo1.py"),
args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")}, args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
log_dir=self.log_dir(), log_dir=self.log_dir(),
@ -429,7 +485,7 @@ class StartProcessesTest(unittest.TestCase):
def test_binary_raises(self): def test_binary_raises(self):
pc = start_processes( pc = start_processes(
name="echo", name="echo",
entrypoint=self.bin("echo2.py"), entrypoint=bin("echo2.py"),
args={0: ("--raises", "true", "foo"), 1: ("bar",)}, args={0: ("--raises", "true", "foo"), 1: ("bar",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
log_dir=self.log_dir(), log_dir=self.log_dir(),
@ -540,7 +596,7 @@ class StartProcessesListTest(StartProcessesTest):
with self.subTest(redirs=redirs): with self.subTest(redirs=redirs):
pc = start_processes( pc = start_processes(
name="echo", name="echo",
entrypoint=self.bin("echo1.py"), entrypoint=bin("echo1.py"),
args={0: ("hello",), 1: ("hello",)}, args={0: ("hello",), 1: ("hello",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
log_dir=self.log_dir(), log_dir=self.log_dir(),
@ -572,7 +628,7 @@ class StartProcessesListTest(StartProcessesTest):
def test_binary_redirect_and_tee(self): def test_binary_redirect_and_tee(self):
pc = start_processes( pc = start_processes(
name="trainer", name="trainer",
entrypoint=self.bin("echo1.py"), entrypoint=bin("echo1.py"),
args={0: ("hello",), 1: ("world",)}, args={0: ("hello",), 1: ("world",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
log_dir=self.log_dir(), log_dir=self.log_dir(),
@ -631,7 +687,7 @@ class StartProcessesNotCITest(StartProcessesTest):
def test_binary_signal(self): def test_binary_signal(self):
pc = start_processes( pc = start_processes(
name="echo", name="echo",
entrypoint=self.bin("echo3.py"), entrypoint=bin("echo3.py"),
args={0: ("--segfault", "true", "foo"), 1: ("bar",)}, args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
log_dir=self.log_dir(), log_dir=self.log_dir(),
@ -755,6 +811,43 @@ class StartProcessesNotCITest(StartProcessesTest):
self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped())
def test_no_zombie_process_binary(self):
    """Runs the zombie workflow with a binary entrypoint, once per signal."""
    script = bin("zombie_test.py")
    for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT):
        self._test_zombie_workflow(script, sig)
def test_no_zombie_process_function(self):
    """Runs the zombie workflow with a function entrypoint, once per signal."""
    for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT):
        self._test_zombie_workflow(wait_fn, sig)
def _test_zombie_workflow(
    self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
) -> None:
    """
    Spawns a parent process that starts two workers, sends ``signal_to_send``
    to the parent and verifies that no worker process is left behind.

    Args:
        entrypoint: Binary path or function the workers run.
        signal_to_send: Signal delivered to the parent process.
    """
    mp_queue = mp.get_context("spawn").Queue()
    child_nproc = 2
    # ``join=False``: the spawned parent keeps running while this test goes on
    # to signal it; the returned context is only held, never joined.
    ctx = mp.spawn(  # noqa: F841
        start_processes_zombie_test,
        nprocs=1,
        args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
        join=False,
    )
    total_processes = child_nproc + 1
    # The parent publishes its own pid first, then one pid per worker.
    pids = [mp_queue.get(timeout=120) for _ in range(total_processes)]
    parent_pid, *child_pids = pids

    # Deliver the requested signal (previously SIGTERM was always sent, so the
    # other signals passed by the callers were never exercised).
    os.kill(parent_pid, signal_to_send)
    # Wait to give time for signal handlers to finish work
    time.sleep(5)
    for child_pid in child_pids:
        # Killing parent should kill all children, we expect that each call to
        # os.kill would raise OSError
        with self.assertRaises(OSError):
            os.kill(child_pid, 0)
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -0,0 +1,14 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import time
def main() -> None:
    """Sleeps for 600 seconds and then prints a completion message."""
    time.sleep(600)
    print("finished work")


if __name__ == "__main__":
    main()

View File

@ -10,6 +10,7 @@ import abc
import functools import functools
import json import json
import os import os
import signal
import socket import socket
import time import time
import traceback import traceback
@ -24,7 +25,11 @@ import torch.distributed.elastic.utils.store as store_util
from torch.distributed import Store from torch.distributed import Store
from torch.distributed.elastic.events import Event, EventSource, record from torch.distributed.elastic.events import Event, EventSource, record
from torch.distributed.elastic.metrics import prof, put_metric from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import ProcessFailure, Std from torch.distributed.elastic.multiprocessing import (
ProcessFailure,
Std,
SignalException,
)
from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.elastic.utils.logging import get_logger
@ -488,9 +493,12 @@ class SimpleElasticAgent(ElasticAgent):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def _shutdown(self) -> None: def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
""" """
Cleans up any resources that were allocated during the agent's work. Cleans up any resources that were allocated during the agent's work.
Args:
death_sig: Signal to send to the child process, SIGTERM is default
""" """
raise NotImplementedError() raise NotImplementedError()
@ -696,16 +704,23 @@ class SimpleElasticAgent(ElasticAgent):
@prof @prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult: def run(self, role: str = DEFAULT_ROLE) -> RunResult:
start_time = time.monotonic() start_time = time.monotonic()
shutdown_called: bool = False
try: try:
result = self._invoke_run(role) result = self._invoke_run(role)
self._total_execution_time = int(time.monotonic() - start_time) self._total_execution_time = int(time.monotonic() - start_time)
self._record_metrics(result) self._record_metrics(result)
self._record_worker_events(result) self._record_worker_events(result)
return result return result
except SignalException as e:
log.warning(f"Received {e.sigval} death signal, shutting down workers")
self._shutdown(e.sigval)
shutdown_called = True
raise
finally: finally:
if not shutdown_called:
self._shutdown()
# record the execution time in case there were any exceptions during run. # record the execution time in case there were any exceptions during run.
self._total_execution_time = int(time.monotonic() - start_time) self._total_execution_time = int(time.monotonic() - start_time)
self._shutdown()
def get_agent_status_event(self, state: WorkerState) -> Event: def get_agent_status_event(self, state: WorkerState) -> Event:
raw_error = traceback.format_exc() if state == WorkerState.FAILED else None raw_error = traceback.format_exc() if state == WorkerState.FAILED else None
@ -891,6 +906,9 @@ class SimpleElasticAgent(ElasticAgent):
log.info( log.info(
f"Done waiting for other agents. Elapsed: {time.time() - start} seconds" f"Done waiting for other agents. Elapsed: {time.time() - start} seconds"
) )
except SignalException as e:
log.warn(f"Got termination signal: {e.sigval}")
raise
except Exception: except Exception:
log.exception( log.exception(
f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds" f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"

View File

@ -9,6 +9,7 @@
import os import os
import shutil import shutil
import signal
import tempfile import tempfile
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
@ -183,9 +184,9 @@ class LocalElasticAgent(SimpleElasticAgent):
return self._pcontext.pids() return self._pcontext.pids()
def _shutdown(self) -> None: def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
if self._pcontext: if self._pcontext:
self._pcontext.close() self._pcontext.close(death_sig)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`. # `torch.distributed.elastic.metrics.prof`.

View File

@ -72,6 +72,7 @@ from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
ProcessFailure, ProcessFailure,
RunProcsResult, RunProcsResult,
Std, Std,
SignalException,
SubprocessContext, SubprocessContext,
_validate_full_rank, _validate_full_rank,
to_map, to_map,

View File

@ -18,6 +18,7 @@ from contextlib import AbstractContextManager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import IntFlag from enum import IntFlag
from multiprocessing import synchronize from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
import torch.multiprocessing as mp import torch.multiprocessing as mp
@ -35,6 +36,50 @@ IS_MACOS = sys.platform == "darwin"
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class SignalException(Exception):
    """
    Exception is raised inside the torchelastic agent process by the termination handler
    if the death signal got received by the process.

    Args:
        msg: Human readable description of the event.
        sigval: The signal that was received; available as the ``sigval`` attribute.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval

    def __reduce__(self):
        # By default an exception is rebuilt as ``cls(*self.args)``. ``args`` only
        # holds ``msg``, so unpickling would fail with a ``TypeError`` for the
        # missing ``sigval``. Rebuild with both constructor arguments instead and
        # carry the instance ``__dict__`` along as state.
        return (type(self), (*self.args, self.sigval), dict(self.__dict__))
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises exceptions on the main process.

    When the process receives a death signal (SIGTERM and SIGINT, plus SIGHUP and
    SIGQUIT on non-Windows platforms), this termination handler will be invoked.
    It raises the ``SignalException`` exception that should be processed by the
    user code. Python does not terminate the process after the termination handler
    is finished, so the exception should not be silently ignored, otherwise the
    process will never be terminated.

    Args:
        signum: Number of the received signal.
        frame: Stack frame that was interrupted; the interpreter may pass ``None``.
            It is not used.

    Raises:
        SignalException: always, carrying the received signal as ``sigval``.
    """
    sigval = signal.Signals(signum)
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
def _get_kill_signal() -> signal.Signals:
    """
    Returns the signal used to forcefully stop a process: ``SIGKILL`` on unix,
    ``CTRL_C_EVENT`` on windows.
    """
    if not IS_WINDOWS:
        return signal.SIGKILL
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
def _get_default_signal() -> signal.Signals:
    """
    Returns the signal used for a regular termination request: ``SIGTERM`` on
    unix, ``CTRL_C_EVENT`` on windows.
    """
    if not IS_WINDOWS:
        return signal.SIGTERM
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str): def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys()) actual_keys = set(d.keys())
expected_keys = set(range(nprocs)) expected_keys = set(range(nprocs))
@ -185,6 +230,11 @@ class PContext(abc.ABC):
""" """
Start processes using parameters defined in the constructor. Start processes using parameters defined in the constructor.
""" """
signal.signal(signal.SIGTERM, _terminate_process_handler)
signal.signal(signal.SIGINT, _terminate_process_handler)
if not IS_WINDOWS:
signal.signal(signal.SIGHUP, _terminate_process_handler)
signal.signal(signal.SIGQUIT, _terminate_process_handler)
self._start() self._start()
self._stdout_tail.start() self._stdout_tail.start()
self._stderr_tail.start() self._stderr_tail.start()
@ -214,6 +264,23 @@ class PContext(abc.ABC):
on timeout expiry. Negative timeout values are interpreted as "wait-forever". on timeout expiry. Negative timeout values are interpreted as "wait-forever".
A timeout value of zero simply queries the status of the processes (e.g. equivalent A timeout value of zero simply queries the status of the processes (e.g. equivalent
to a poll). to a poll).
.. note:: The multiprocessing library registers SIGTERM and SIGINT signal handlers that raise
``SignalException`` when the signals are received. It is up to the consumer of the code
to properly handle the exception. It is important not to swallow the exception, otherwise
the process will not terminate. An example of a typical workflow:
.. code-block:: python
pc = start_processes(...)
try:
pc.wait(1)
.. do some other work
except SignalException as e:
pc.close(e.sigval, timeout=30)
If SIGTERM or SIGINT occurs, the code above will try to shut down the child processes by propagating
the received signal. If the child processes do not terminate within the timeout, the process will send
them SIGKILL.
""" """
if timeout == 0: if timeout == 0:
@ -239,15 +306,28 @@ class PContext(abc.ABC):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def _close(self) -> None: def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
r""" r"""
Terminates all processes managed by this context and cleans up any Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files). meta resources (e.g. redirect, error_file files).
""" """
raise NotImplementedError() raise NotImplementedError()
def close(self) -> None: def close(
self._close() self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
) -> None:
r"""
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
Args:
death_sig: Death signal to terminate processes.
timeout: Time to wait for processes to finish, if process is
still alive after this time, it will be terminated via SIGKILL.
"""
if not death_sig:
death_sig = _get_default_signal()
self._close(death_sig=death_sig, timeout=timeout)
if self._stdout_tail: if self._stdout_tail:
self._stdout_tail.stop() self._stdout_tail.stop()
if self._stderr_tail: if self._stderr_tail:
@ -447,10 +527,34 @@ class MultiprocessContext(PContext):
assert self._pc is not None # assertion for mypy type checking assert self._pc is not None # assertion for mypy type checking
return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())} return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())}
def _close(self) -> None: def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if self._pc: if not self._pc:
return
log.warning(f"Closing processes via signal {death_sig}")
for proc in self._pc.processes: for proc in self._pc.processes:
proc.terminate() try:
os.kill(proc.pid, death_sig)
except ProcessLookupError:
# If the process exited because of some reason,
# `ProcessLookupError` will be raised, it is safe to ignore it.
pass
end = time.monotonic() + timeout
for proc in self._pc.processes:
time_to_wait = end - time.monotonic()
if time_to_wait <= 0:
break
proc.join(time_to_wait)
log.warning(
f"Unable to shutdown processes via {death_sig}, forcefully exitting via {_get_kill_signal()}"
)
for proc in self._pc.processes:
if proc.is_alive():
try:
os.kill(proc.pid, _get_kill_signal())
except ProcessLookupError:
# If the process exited because of some reason,
# `ProcessLookupError` will be raised, it is safe to ignore it.
pass
proc.join() proc.join()
@ -465,7 +569,6 @@ class SubprocessHandler:
entrypoint: str, entrypoint: str,
args: Tuple, args: Tuple,
env: Dict[str, str], env: Dict[str, str],
preexec_fn: Optional[Callable],
stdout: str, stdout: str,
stderr: str, stderr: str,
): ):
@ -476,46 +579,29 @@ class SubprocessHandler:
env_vars.update(env) env_vars.update(env)
args_str = (entrypoint, *[str(e) for e in args]) args_str = (entrypoint, *[str(e) for e in args])
self.proc: subprocess.Popen = self._popen(args_str, env_vars, preexec_fn) self.proc: subprocess.Popen = self._popen(args_str, env_vars)
def _popen(
self, args: Tuple, env: Dict[str, str], preexec_fn: Optional[Callable]
) -> subprocess.Popen:
if IS_WINDOWS:
# Reset preexec_fn on windows, since windows does not support it
preexec_fn = None
def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
return subprocess.Popen( return subprocess.Popen(
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
# `Tuple[str, *Tuple[Any, ...]]`. # `Tuple[str, *Tuple[Any, ...]]`.
args=args, args=args,
env=env, env=env,
preexec_fn=preexec_fn,
stdout=self._stdout, stdout=self._stdout,
stderr=self._stderr, stderr=self._stderr,
) )
def close(self): def close(self, death_sig: Optional[signal.Signals] = None) -> None:
self.proc.terminate() if not death_sig:
self.proc.wait() death_sig = _get_default_signal()
self.proc.send_signal(death_sig)
if self._stdout: if self._stdout:
self._stdout.close() self._stdout.close()
if self._stderr: if self._stderr:
self._stderr.close() self._stderr.close()
def _pr_set_pdeathsig() -> None:
"""
Sets PR_SET_PDEATHSIG to ensure a child process is
terminated appropriately.
See http://stackoverflow.com/questions/1884941/ for more information.
For libc.so.6 read http://www.linux-m68k.org/faq/glibcinfo.html
"""
mp._prctl_pr_set_pdeathsig(signal.SIGTERM) # type: ignore[attr-defined]
class SubprocessContext(PContext): class SubprocessContext(PContext):
""" """
``PContext`` holding worker processes invoked as a binary. ``PContext`` holding worker processes invoked as a binary.
@ -560,7 +646,6 @@ class SubprocessContext(PContext):
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
args=self.args[local_rank], args=self.args[local_rank],
env=self.envs[local_rank], env=self.envs[local_rank],
preexec_fn=_pr_set_pdeathsig,
stdout=self.stdouts[local_rank], stdout=self.stdouts[local_rank],
stderr=self.stderrs[local_rank], stderr=self.stderrs[local_rank],
) )
@ -616,7 +701,21 @@ class SubprocessContext(PContext):
for local_rank, sh in self.subprocess_handlers.items() for local_rank, sh in self.subprocess_handlers.items()
} }
def _close(self) -> None: def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if self.subprocess_handlers: if not self.subprocess_handlers:
return
log.warning(f"Sending processes {death_sig}")
for handler in self.subprocess_handlers.values(): for handler in self.subprocess_handlers.values():
handler.close() handler.close(death_sig=death_sig)
end = time.monotonic() + timeout
for handler in self.subprocess_handlers.values():
time_to_wait = end - time.monotonic()
if time_to_wait <= 0:
break
handler.proc.wait(time_to_wait)
log.warning(
f"Unable to shutdown processes via {death_sig}, forcefully exitting via {_get_kill_signal()}"
)
for handler in self.subprocess_handlers.values():
handler.close(death_sig=_get_kill_signal())
handler.proc.wait()