mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchelastic] Improve process termination logic (#61602)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61602 The diff introduces signal handlers and a `SignalException` that is raised when the agent process receives SIGTERM or SIGINT. When either of these signals is received, the termination handler raises the `SignalException`. The exception is then processed by the main agent loop, which invokes `shutdown(signum)` to propagate the received signal to the child processes. A default timeout of 30 seconds is introduced: if the child processes are not able to terminate gracefully within this timeout, the agent process kills them via SIGKILL. Test Plan: unittests, sandcastle Reviewed By: cbalioglu Differential Revision: D29671783 fbshipit-source-id: 3dbca2125676dc18d417cc3e3bb0301fdd42737a
This commit is contained in:
parent
e42360d56f
commit
0c55f1bdec
|
|
@ -7,6 +7,7 @@
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import signal
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
@ -22,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import (
|
||||||
_get_fq_hostname,
|
_get_fq_hostname,
|
||||||
_RoleInstanceInfo,
|
_RoleInstanceInfo,
|
||||||
)
|
)
|
||||||
|
from torch.distributed.elastic.multiprocessing import SignalException
|
||||||
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
|
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
|
||||||
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
|
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
|
||||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||||
|
|
@ -550,6 +552,20 @@ class SimpleElasticAgentTest(unittest.TestCase):
|
||||||
self.assertEqual(spec.role, actual_event.metadata["role"])
|
self.assertEqual(spec.role, actual_event.metadata["role"])
|
||||||
self.assertEqual(2, actual_event.metadata["agent_restarts"])
|
self.assertEqual(2, actual_event.metadata["agent_restarts"])
|
||||||
|
|
||||||
|
@patch("torch.distributed.elastic.agent.server.api.put_metric")
@patch.object(TestAgent, "_invoke_run")
def test_agent_process_signal_exception(self, invoke_run, put_metric_mock):
    """A ``SignalException`` from the run loop must shut the workers down.

    ``_invoke_run`` is patched to raise ``SignalException`` carrying SIGTERM.
    ``agent.run()`` is expected to re-raise the exception after calling
    ``_shutdown`` with the received signal.
    """
    spec = self._get_worker_spec(max_restarts=0)
    agent = TestAgent(spec)
    invoke_run.side_effect = SignalException(
        "signal exception", sigval=signal.SIGTERM
    )
    with patch.object(agent, "_shutdown") as shutdown_mock:
        with self.assertRaises(SignalException):
            agent.run()
        # Fails with a readable assertion message if ``_shutdown`` was never
        # called, and also verifies that shutdown is not invoked a second
        # time (without the signal) while the exception propagates.
        shutdown_mock.assert_called_once_with(signal.SIGTERM)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import tempfile
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Union, Callable
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
@ -24,6 +24,7 @@ import torch.multiprocessing as mp
|
||||||
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
|
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
|
||||||
from torch.distributed.elastic.multiprocessing.api import (
|
from torch.distributed.elastic.multiprocessing.api import (
|
||||||
MultiprocessContext,
|
MultiprocessContext,
|
||||||
|
SignalException,
|
||||||
RunProcsResult,
|
RunProcsResult,
|
||||||
Std,
|
Std,
|
||||||
_validate_full_rank,
|
_validate_full_rank,
|
||||||
|
|
@ -173,6 +174,52 @@ def redirects_all() -> List[Std]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def bin(name: str) -> str:
    """Return the path of the helper script ``name`` inside the ``bin``
    directory located next to this test file.

    Args:
        name: File name of the script, e.g. ``"echo1.py"``.
    """
    # The local is deliberately not called ``dir`` so the builtin stays usable.
    test_dir = os.path.dirname(__file__)
    return os.path.join(test_dir, "bin", name)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_fn(wait_time: int = 300) -> None:
    """Sleep for ``wait_time`` seconds and then print a completion message.

    Args:
        wait_time: Number of seconds to sleep before returning.
    """
    time.sleep(wait_time)
    print("Finished waiting")
|
||||||
|
|
||||||
|
|
||||||
|
def start_processes_zombie_test(
    idx: int,
    entrypoint: Union[str, Callable],
    mp_queue: mp.Queue,
    log_dir: str,
    nproc: int = 2,
) -> None:
    """
    Starts ``nproc`` worker processes and reports all pids through ``mp_queue``.

    The pid of the current process is put on the queue first, followed by the
    pid of every started worker. The function then waits on the workers and,
    if a ``SignalException`` is raised while waiting, closes the process
    context with the received signal.

    Args:
        idx: Index of this process (first positional argument of the spawn
            target); not used by the body.
        entrypoint: Binary path or function to run in every worker.
        mp_queue: Queue on which the pids are published.
        log_dir: Directory handed to ``start_processes`` for worker logs.
        nproc: Number of worker processes to start.
    """
    # Build the per-rank maps with their own loop variable so that the ``idx``
    # parameter is not overwritten.
    args = {local_rank: () for local_rank in range(nproc)}
    envs = {local_rank: {} for local_rank in range(nproc)}

    pc = start_processes(
        name="zombie_test",
        entrypoint=entrypoint,
        args=args,
        envs=envs,
        log_dir=log_dir,
        redirects=Std.NONE,
    )
    # The consumer treats the first pid as the parent and the rest as children.
    mp_queue.put(os.getpid())
    for child_pid in pc.pids().values():
        mp_queue.put(child_pid)

    try:
        pc.wait(period=1, timeout=300)
    except SignalException as e:
        # Propagate the received signal to the workers.
        pc.close(e.sigval)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS,
|
TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS,
|
||||||
"tests incompatible with tsan or asan",
|
"tests incompatible with tsan or asan",
|
||||||
|
|
@ -294,6 +341,19 @@ class StartProcessesTest(unittest.TestCase):
|
||||||
self.assertTrue(pc._stderr_tail.stopped())
|
self.assertTrue(pc._stderr_tail.stopped())
|
||||||
self.assertTrue(pc._stdout_tail.stopped())
|
self.assertTrue(pc._stdout_tail.stopped())
|
||||||
|
|
||||||
|
def test_subprocess_context_close(self):
    """Closing a context started from a binary must leave no worker pid behind."""
    zombie_script = bin("zombie_test.py")
    pc = start_processes(
        name="sleep",
        entrypoint=zombie_script,
        args={0: (1,)},
        envs={0: {}},
        log_dir=self.log_dir(),
    )

    # Capture the pids before closing so they can be checked afterwards.
    worker_pids = pc.pids()
    pc.close()
    self.assert_pids_noexist(worker_pids)
|
||||||
|
|
||||||
def test_function_with_tensor(self):
|
def test_function_with_tensor(self):
|
||||||
for start_method in self._start_methods:
|
for start_method in self._start_methods:
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
|
|
@ -395,15 +455,11 @@ class StartProcessesTest(unittest.TestCase):
|
||||||
# start_processes as binary tests
|
# start_processes as binary tests
|
||||||
########################################
|
########################################
|
||||||
|
|
||||||
def bin(self, name: str):
|
|
||||||
dir = os.path.dirname(__file__)
|
|
||||||
return os.path.join(dir, "bin", name)
|
|
||||||
|
|
||||||
def test_binary_exit(self):
|
def test_binary_exit(self):
|
||||||
FAIL = 138
|
FAIL = 138
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
name="echo",
|
name="echo",
|
||||||
entrypoint=self.bin("echo1.py"),
|
entrypoint=bin("echo1.py"),
|
||||||
args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
|
args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
|
||||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||||
log_dir=self.log_dir(),
|
log_dir=self.log_dir(),
|
||||||
|
|
@ -429,7 +485,7 @@ class StartProcessesTest(unittest.TestCase):
|
||||||
def test_binary_raises(self):
|
def test_binary_raises(self):
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
name="echo",
|
name="echo",
|
||||||
entrypoint=self.bin("echo2.py"),
|
entrypoint=bin("echo2.py"),
|
||||||
args={0: ("--raises", "true", "foo"), 1: ("bar",)},
|
args={0: ("--raises", "true", "foo"), 1: ("bar",)},
|
||||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||||
log_dir=self.log_dir(),
|
log_dir=self.log_dir(),
|
||||||
|
|
@ -463,7 +519,7 @@ class StartProcessesTest(unittest.TestCase):
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
NO_MULTIPROCESSING_SPAWN,
|
NO_MULTIPROCESSING_SPAWN,
|
||||||
"Disabled for environments that \
|
"Disabled for environments that \
|
||||||
don't support multiprocessing with spawn start method",
|
don't support multiprocessing with spawn start method",
|
||||||
)
|
)
|
||||||
def test_multiprocessing_context_poll_raises_exception(self):
|
def test_multiprocessing_context_poll_raises_exception(self):
|
||||||
mp_context = MultiprocessContext(
|
mp_context = MultiprocessContext(
|
||||||
|
|
@ -540,7 +596,7 @@ class StartProcessesListTest(StartProcessesTest):
|
||||||
with self.subTest(redirs=redirs):
|
with self.subTest(redirs=redirs):
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
name="echo",
|
name="echo",
|
||||||
entrypoint=self.bin("echo1.py"),
|
entrypoint=bin("echo1.py"),
|
||||||
args={0: ("hello",), 1: ("hello",)},
|
args={0: ("hello",), 1: ("hello",)},
|
||||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||||
log_dir=self.log_dir(),
|
log_dir=self.log_dir(),
|
||||||
|
|
@ -572,7 +628,7 @@ class StartProcessesListTest(StartProcessesTest):
|
||||||
def test_binary_redirect_and_tee(self):
|
def test_binary_redirect_and_tee(self):
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
name="trainer",
|
name="trainer",
|
||||||
entrypoint=self.bin("echo1.py"),
|
entrypoint=bin("echo1.py"),
|
||||||
args={0: ("hello",), 1: ("world",)},
|
args={0: ("hello",), 1: ("world",)},
|
||||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||||
log_dir=self.log_dir(),
|
log_dir=self.log_dir(),
|
||||||
|
|
@ -631,7 +687,7 @@ class StartProcessesNotCITest(StartProcessesTest):
|
||||||
def test_binary_signal(self):
|
def test_binary_signal(self):
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
name="echo",
|
name="echo",
|
||||||
entrypoint=self.bin("echo3.py"),
|
entrypoint=bin("echo3.py"),
|
||||||
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
|
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
|
||||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||||
log_dir=self.log_dir(),
|
log_dir=self.log_dir(),
|
||||||
|
|
@ -755,6 +811,43 @@ class StartProcessesNotCITest(StartProcessesTest):
|
||||||
self.assertTrue(pc._stderr_tail.stopped())
|
self.assertTrue(pc._stderr_tail.stopped())
|
||||||
self.assertTrue(pc._stdout_tail.stopped())
|
self.assertTrue(pc._stdout_tail.stopped())
|
||||||
|
|
||||||
|
def test_no_zombie_process_binary(self):
    """Run the zombie workflow with a binary entrypoint, once per signal."""
    for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT):
        self._test_zombie_workflow(bin("zombie_test.py"), sig)
|
||||||
|
|
||||||
|
def test_no_zombie_process_function(self):
    """Run the zombie workflow with a function entrypoint, once per signal."""
    for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT):
        self._test_zombie_workflow(wait_fn, sig)
|
||||||
|
|
||||||
|
def _test_zombie_workflow(
    self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
) -> None:
    """Check that signalling the parent process leaves no live children.

    A parent process running ``start_processes_zombie_test`` is spawned; it
    publishes its own pid followed by the pids of its workers on a queue.
    ``signal_to_send`` is then delivered to the parent, and after a grace
    period every child pid is probed with signal 0, which must fail.

    Args:
        entrypoint: Binary path or function the workers execute.
        signal_to_send: Signal delivered to the parent process.
    """
    mp_queue = mp.get_context("spawn").Queue()
    child_nproc = 2
    # ``join=False``: the call returns immediately and the parent keeps
    # running until it is signalled below.
    ctx = mp.spawn(
        start_processes_zombie_test,
        nprocs=1,
        args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
        join=False,
    )
    # One pid for the parent plus one per child.
    total_processes = child_nproc + 1
    pids = [mp_queue.get(timeout=120) for _ in range(total_processes)]
    parent_pid = pids[0]
    child_pids = pids[1:]

    # Deliver the signal under test rather than a hard-coded SIGTERM, so each
    # caller-supplied signal is actually exercised.
    os.kill(parent_pid, signal_to_send)
    # Wait to give time for signal handlers to finish work
    time.sleep(5)
    for child_pid in child_pids:
        # Killing parent should kill all children, we expect that each call to
        # os.kill would raise OSError
        with self.assertRaises(OSError):
            os.kill(child_pid, 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
14
test/distributed/elastic/multiprocessing/bin/zombie_test.py
Executable file
14
test/distributed/elastic/multiprocessing/bin/zombie_test.py
Executable file
|
|
@ -0,0 +1,14 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Stay alive for ten minutes unless terminated earlier; presumably long
    # enough to outlive the test that launches this script.
    time.sleep(600)
    print("finished work")
|
||||||
|
|
@ -10,6 +10,7 @@ import abc
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -24,7 +25,11 @@ import torch.distributed.elastic.utils.store as store_util
|
||||||
from torch.distributed import Store
|
from torch.distributed import Store
|
||||||
from torch.distributed.elastic.events import Event, EventSource, record
|
from torch.distributed.elastic.events import Event, EventSource, record
|
||||||
from torch.distributed.elastic.metrics import prof, put_metric
|
from torch.distributed.elastic.metrics import prof, put_metric
|
||||||
from torch.distributed.elastic.multiprocessing import ProcessFailure, Std
|
from torch.distributed.elastic.multiprocessing import (
|
||||||
|
ProcessFailure,
|
||||||
|
Std,
|
||||||
|
SignalException,
|
||||||
|
)
|
||||||
from torch.distributed.elastic.utils.logging import get_logger
|
from torch.distributed.elastic.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -488,9 +493,12 @@ class SimpleElasticAgent(ElasticAgent):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _shutdown(self) -> None:
|
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
|
||||||
"""
|
"""
|
||||||
Cleans up any resources that were allocated during the agent's work.
|
Cleans up any resources that were allocated during the agent's work.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
death_sig: Signal to send to the child process, SIGTERM is default
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
@ -696,16 +704,23 @@ class SimpleElasticAgent(ElasticAgent):
|
||||||
@prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
    """Run the agent for ``role`` and shut the workers down on exit.

    On success the run result is recorded (metrics and worker events) and
    returned. If a ``SignalException`` is raised, the workers are shut down
    with the received signal and the exception is re-raised. On every other
    exit path ``_shutdown`` is called with its default signal.

    Args:
        role: Role to run.

    Returns:
        The result produced by ``_invoke_run``.

    Raises:
        SignalException: Re-raised after the workers have been shut down.
    """
    began_at = time.monotonic()
    # Tracks whether the signal path already shut the workers down, so the
    # ``finally`` clause does not do it a second time.
    workers_shut_down: bool = False
    try:
        run_result = self._invoke_run(role)
        self._total_execution_time = int(time.monotonic() - began_at)
        self._record_metrics(run_result)
        self._record_worker_events(run_result)
        return run_result
    except SignalException as e:
        log.warning(f"Received {e.sigval} death signal, shutting down workers")
        self._shutdown(e.sigval)
        workers_shut_down = True
        raise
    finally:
        if not workers_shut_down:
            self._shutdown()
        # record the execution time in case there were any exceptions during run.
        self._total_execution_time = int(time.monotonic() - began_at)
|
|
||||||
|
|
||||||
def get_agent_status_event(self, state: WorkerState) -> Event:
|
def get_agent_status_event(self, state: WorkerState) -> Event:
|
||||||
raw_error = traceback.format_exc() if state == WorkerState.FAILED else None
|
raw_error = traceback.format_exc() if state == WorkerState.FAILED else None
|
||||||
|
|
@ -891,6 +906,9 @@ class SimpleElasticAgent(ElasticAgent):
|
||||||
log.info(
|
log.info(
|
||||||
f"Done waiting for other agents. Elapsed: {time.time() - start} seconds"
|
f"Done waiting for other agents. Elapsed: {time.time() - start} seconds"
|
||||||
)
|
)
|
||||||
|
except SignalException as e:
|
||||||
|
log.warn(f"Got termination signal: {e.sigval}")
|
||||||
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception(
|
log.exception(
|
||||||
f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
|
f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import signal
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
|
@ -183,9 +184,9 @@ class LocalElasticAgent(SimpleElasticAgent):
|
||||||
|
|
||||||
return self._pcontext.pids()
|
return self._pcontext.pids()
|
||||||
|
|
||||||
def _shutdown(self) -> None:
|
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
|
||||||
if self._pcontext:
|
if self._pcontext:
|
||||||
self._pcontext.close()
|
self._pcontext.close(death_sig)
|
||||||
|
|
||||||
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
|
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
|
||||||
# `torch.distributed.elastic.metrics.prof`.
|
# `torch.distributed.elastic.metrics.prof`.
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
|
||||||
ProcessFailure,
|
ProcessFailure,
|
||||||
RunProcsResult,
|
RunProcsResult,
|
||||||
Std,
|
Std,
|
||||||
|
SignalException,
|
||||||
SubprocessContext,
|
SubprocessContext,
|
||||||
_validate_full_rank,
|
_validate_full_rank,
|
||||||
to_map,
|
to_map,
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from contextlib import AbstractContextManager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import IntFlag
|
from enum import IntFlag
|
||||||
from multiprocessing import synchronize
|
from multiprocessing import synchronize
|
||||||
|
from types import FrameType
|
||||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
@ -35,6 +36,50 @@ IS_MACOS = sys.platform == "darwin"
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalException(Exception):
    """
    Exception raised inside the torchelastic agent process by the termination
    handler when a death signal is received by the process.

    Args:
        msg: Human-readable description of the event.
        sigval: The signal that was received.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval

    def __reduce__(self):
        # Exceptions are re-created on unpickling/copying by calling
        # ``cls(*self.args)``. ``args`` only holds ``msg``, so the default
        # behaviour fails with a TypeError for the missing ``sigval``;
        # supply it explicitly and carry any extra attributes as state.
        return (type(self), (*self.args, self.sigval), vars(self))
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises exceptions on the main process.

    When the process receives a death signal this handler is registered for
    (e.g. SIGTERM, SIGINT), it raises ``SignalException``, which should be
    processed by the user code. Python does not terminate the process after
    the termination handler is finished, so the exception should not be
    silently ignored, otherwise the process will never be terminated.

    Args:
        signum: Number of the received signal.
        frame: Stack frame that was interrupted, or ``None``; unused.

    Raises:
        SignalException: Always, carrying the received signal as ``sigval``.
    """
    sigval = signal.Signals(signum)
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_kill_signal() -> signal.Signals:
    """
    Return the signal used to forcefully stop a process:
    CTRL_C_EVENT on windows, SIGKILL everywhere else.
    """
    if not IS_WINDOWS:
        return signal.SIGKILL
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_signal() -> signal.Signals:
    """
    Return the default termination signal:
    CTRL_C_EVENT on windows, SIGTERM everywhere else.
    """
    if not IS_WINDOWS:
        return signal.SIGTERM
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
|
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
|
||||||
actual_keys = set(d.keys())
|
actual_keys = set(d.keys())
|
||||||
expected_keys = set(range(nprocs))
|
expected_keys = set(range(nprocs))
|
||||||
|
|
@ -185,6 +230,11 @@ class PContext(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Start processes using parameters defined in the constructor.
|
Start processes using parameters defined in the constructor.
|
||||||
"""
|
"""
|
||||||
|
signal.signal(signal.SIGTERM, _terminate_process_handler)
|
||||||
|
signal.signal(signal.SIGINT, _terminate_process_handler)
|
||||||
|
if not IS_WINDOWS:
|
||||||
|
signal.signal(signal.SIGHUP, _terminate_process_handler)
|
||||||
|
signal.signal(signal.SIGQUIT, _terminate_process_handler)
|
||||||
self._start()
|
self._start()
|
||||||
self._stdout_tail.start()
|
self._stdout_tail.start()
|
||||||
self._stderr_tail.start()
|
self._stderr_tail.start()
|
||||||
|
|
@ -214,6 +264,23 @@ class PContext(abc.ABC):
|
||||||
on timeout expiry. Negative timeout values are interpreted as "wait-forever".
|
on timeout expiry. Negative timeout values are interpreted as "wait-forever".
|
||||||
A timeout value of zero simply queries the status of the processes (e.g. equivalent
|
A timeout value of zero simply queries the status of the processes (e.g. equivalent
|
||||||
to a poll).
|
to a poll).
|
||||||
|
|
||||||
|
.. note:: The multiprocessing library registers SIGTERM and SIGINT signal handlers that raise
|
||||||
|
``SignalException`` when the signals are received. It is up to the consumer of the code
|
||||||
|
to properly handle the exception. It is important not to swallow the exception otherwise
|
||||||
|
the process would not terminate. Example of the typical workflow can be:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
pc = start_processes(...)
|
||||||
|
try:
|
||||||
|
pc.wait(1)
|
||||||
|
.. do some other work
|
||||||
|
except SignalException as e:
|
||||||
|
pc.close(e.sigval, timeout=30)
|
||||||
|
|
||||||
|
If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating
|
||||||
|
the received signal. If the child processes do not terminate within the timeout, the process will send
|
||||||
|
the SIGKILL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if timeout == 0:
|
if timeout == 0:
|
||||||
|
|
@ -239,15 +306,28 @@ class PContext(abc.ABC):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _close(self) -> None:
|
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
|
||||||
r"""
|
r"""
|
||||||
Terminates all processes managed by this context and cleans up any
|
Terminates all processes managed by this context and cleans up any
|
||||||
meta resources (e.g. redirect, error_file files).
|
meta resources (e.g. redirect, error_file files).
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(
|
||||||
self._close()
|
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Terminates all processes managed by this context and cleans up any
|
||||||
|
meta resources (e.g. redirect, error_file files).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
death_sig: Death signal to terminate processes.
|
||||||
|
timeout: Time to wait for processes to finish, if process is
|
||||||
|
still alive after this time, it will be terminated via SIGKILL.
|
||||||
|
"""
|
||||||
|
if not death_sig:
|
||||||
|
death_sig = _get_default_signal()
|
||||||
|
self._close(death_sig=death_sig, timeout=timeout)
|
||||||
if self._stdout_tail:
|
if self._stdout_tail:
|
||||||
self._stdout_tail.stop()
|
self._stdout_tail.stop()
|
||||||
if self._stderr_tail:
|
if self._stderr_tail:
|
||||||
|
|
@ -447,11 +527,35 @@ class MultiprocessContext(PContext):
|
||||||
assert self._pc is not None # assertion for mypy type checking
|
assert self._pc is not None # assertion for mypy type checking
|
||||||
return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())}
|
return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())}
|
||||||
|
|
||||||
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
    """Terminate all processes of this context.

    ``death_sig`` is sent to every process, then the processes are joined
    for up to ``timeout`` seconds in total. Processes that are still alive
    afterwards are stopped with the kill signal and joined.

    Args:
        death_sig: Signal sent first to request termination.
        timeout: Total number of seconds to wait before force-killing.
    """
    if not self._pc:
        return
    log.warning(f"Closing processes via signal {death_sig}")
    for proc in self._pc.processes:
        try:
            os.kill(proc.pid, death_sig)
        except ProcessLookupError:
            # If the process exited because of some reason,
            # `ProcessLookupError` will be raised, it is safe to ignore it.
            pass
    # The deadline is shared by all processes, not applied per process.
    end = time.monotonic() + timeout
    for proc in self._pc.processes:
        time_to_wait = end - time.monotonic()
        if time_to_wait <= 0:
            break
        proc.join(time_to_wait)
    # Only report a forced exit when some process actually survived the
    # graceful phase; a clean shutdown should not log this warning.
    if any(proc.is_alive() for proc in self._pc.processes):
        log.warning(
            f"Unable to shutdown processes via {death_sig}, forcefully exitting via {_get_kill_signal()}"
        )
    for proc in self._pc.processes:
        if proc.is_alive():
            try:
                os.kill(proc.pid, _get_kill_signal())
            except ProcessLookupError:
                # If the process exited because of some reason,
                # `ProcessLookupError` will be raised, it is safe to ignore it.
                pass
        proc.join()
|
||||||
|
|
||||||
|
|
||||||
class SubprocessHandler:
|
class SubprocessHandler:
|
||||||
|
|
@ -465,7 +569,6 @@ class SubprocessHandler:
|
||||||
entrypoint: str,
|
entrypoint: str,
|
||||||
args: Tuple,
|
args: Tuple,
|
||||||
env: Dict[str, str],
|
env: Dict[str, str],
|
||||||
preexec_fn: Optional[Callable],
|
|
||||||
stdout: str,
|
stdout: str,
|
||||||
stderr: str,
|
stderr: str,
|
||||||
):
|
):
|
||||||
|
|
@ -476,46 +579,29 @@ class SubprocessHandler:
|
||||||
env_vars.update(env)
|
env_vars.update(env)
|
||||||
|
|
||||||
args_str = (entrypoint, *[str(e) for e in args])
|
args_str = (entrypoint, *[str(e) for e in args])
|
||||||
self.proc: subprocess.Popen = self._popen(args_str, env_vars, preexec_fn)
|
self.proc: subprocess.Popen = self._popen(args_str, env_vars)
|
||||||
|
|
||||||
def _popen(
|
|
||||||
self, args: Tuple, env: Dict[str, str], preexec_fn: Optional[Callable]
|
|
||||||
) -> subprocess.Popen:
|
|
||||||
if IS_WINDOWS:
|
|
||||||
# Reset preexec_fn on windows, since windows does not support it
|
|
||||||
preexec_fn = None
|
|
||||||
|
|
||||||
|
def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
|
||||||
return subprocess.Popen(
|
return subprocess.Popen(
|
||||||
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
|
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
|
||||||
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
|
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
|
||||||
# `Tuple[str, *Tuple[Any, ...]]`.
|
# `Tuple[str, *Tuple[Any, ...]]`.
|
||||||
args=args,
|
args=args,
|
||||||
env=env,
|
env=env,
|
||||||
preexec_fn=preexec_fn,
|
|
||||||
stdout=self._stdout,
|
stdout=self._stdout,
|
||||||
stderr=self._stderr,
|
stderr=self._stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
def close(self, death_sig: Optional[signal.Signals] = None) -> None:
    """Send ``death_sig`` to the process and close its output files.

    The call does not wait for the process to exit.

    Args:
        death_sig: Signal to send; when ``None`` the platform's default
            termination signal is used.
    """
    # Compare against ``None`` explicitly: a signal whose numeric value is 0
    # is falsy and must not be mistaken for "no signal given".
    if death_sig is None:
        death_sig = _get_default_signal()
    self.proc.send_signal(death_sig)
    if self._stdout:
        self._stdout.close()
    if self._stderr:
        self._stderr.close()
|
||||||
|
|
||||||
|
|
||||||
def _pr_set_pdeathsig() -> None:
|
|
||||||
"""
|
|
||||||
Sets PR_SET_PDEATHSIG to ensure a child process is
|
|
||||||
terminated appropriately.
|
|
||||||
|
|
||||||
See http://stackoverflow.com/questions/1884941/ for more information.
|
|
||||||
For libc.so.6 read http://www.linux-m68k.org/faq/glibcinfo.html
|
|
||||||
"""
|
|
||||||
mp._prctl_pr_set_pdeathsig(signal.SIGTERM) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
|
|
||||||
class SubprocessContext(PContext):
|
class SubprocessContext(PContext):
|
||||||
"""
|
"""
|
||||||
``PContext`` holding worker processes invoked as a binary.
|
``PContext`` holding worker processes invoked as a binary.
|
||||||
|
|
@ -560,7 +646,6 @@ class SubprocessContext(PContext):
|
||||||
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
|
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
|
||||||
args=self.args[local_rank],
|
args=self.args[local_rank],
|
||||||
env=self.envs[local_rank],
|
env=self.envs[local_rank],
|
||||||
preexec_fn=_pr_set_pdeathsig,
|
|
||||||
stdout=self.stdouts[local_rank],
|
stdout=self.stdouts[local_rank],
|
||||||
stderr=self.stderrs[local_rank],
|
stderr=self.stderrs[local_rank],
|
||||||
)
|
)
|
||||||
|
|
@ -616,7 +701,21 @@ class SubprocessContext(PContext):
|
||||||
for local_rank, sh in self.subprocess_handlers.items()
|
for local_rank, sh in self.subprocess_handlers.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
    """Terminate all subprocesses of this context.

    ``death_sig`` is sent to every subprocess, then the subprocesses are
    awaited for up to ``timeout`` seconds in total. Subprocesses that are
    still running afterwards are stopped with the kill signal. Finally all
    subprocesses are waited on.

    Args:
        death_sig: Signal sent first to request termination.
        timeout: Total number of seconds to wait before force-killing.
    """
    if not self.subprocess_handlers:
        return
    log.warning(f"Sending processes {death_sig}")
    for handler in self.subprocess_handlers.values():
        handler.close(death_sig=death_sig)
    # The deadline is shared by all subprocesses, not applied per process.
    end = time.monotonic() + timeout
    for handler in self.subprocess_handlers.values():
        time_to_wait = end - time.monotonic()
        if time_to_wait <= 0:
            break
        try:
            handler.proc.wait(time_to_wait)
        except subprocess.TimeoutExpired:
            # Expected when a subprocess ignores ``death_sig``; it is
            # force-killed below instead of aborting the shutdown.
            pass
    # ``poll()`` returns ``None`` while a subprocess is still running.
    survivors = [
        handler
        for handler in self.subprocess_handlers.values()
        if handler.proc.poll() is None
    ]
    if survivors:
        log.warning(
            f"Unable to shutdown processes via {death_sig}, forcefully exitting via {_get_kill_signal()}"
        )
    for handler in survivors:
        handler.close(death_sig=_get_kill_signal())
    for handler in self.subprocess_handlers.values():
        handler.proc.wait()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user