mirror of
https://github.com/zebrajr/ansible.git
synced 2025-12-06 00:19:48 +01:00
Fix signal propagation (#85907)
This commit is contained in:
parent
9ee667030f
commit
5a9afe4409
3
changelogs/fragments/fix-signal-propagation.yml
Normal file
3
changelogs/fragments/fix-signal-propagation.yml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
bugfixes:
|
||||
- SIGINT/SIGTERM Handling - Make SIGINT/SIGTERM handling more robust by splitting concerns
|
||||
between forks and the parent.
|
||||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import io
|
||||
import os
|
||||
import signal
|
||||
|
|
@ -103,11 +104,19 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
|
|||
self._cliargs = cliargs
|
||||
|
||||
def _term(self, signum, frame) -> None:
|
||||
"""
|
||||
terminate the process group created by calling setsid when
|
||||
a terminate signal is received by the fork
|
||||
"""
|
||||
"""In child termination when notified by the parent"""
|
||||
signal.signal(signum, signal.SIG_DFL)
|
||||
|
||||
try:
|
||||
os.killpg(self.pid, signum)
|
||||
os.kill(self.pid, signum)
|
||||
except OSError as e:
|
||||
if e.errno != errno.ESRCH:
|
||||
signame = signal.strsignal(signum)
|
||||
display.error(f'Unable to send {signame} to child[{self.pid}]: {e}')
|
||||
|
||||
# fallthrough, if we are still here, just die
|
||||
os._exit(1)
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
|
|
@ -121,11 +130,6 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
|
|||
# FUTURE: this lock can be removed once a more generalized pre-fork thread pause is in place
|
||||
with display._lock:
|
||||
super(WorkerProcess, self).start()
|
||||
# Since setsid is called later, if the worker is termed
|
||||
# it won't term the new process group
|
||||
# register a handler to propagate the signal
|
||||
signal.signal(signal.SIGTERM, self._term)
|
||||
signal.signal(signal.SIGINT, self._term)
|
||||
|
||||
def _hard_exit(self, e: str) -> t.NoReturn:
|
||||
"""
|
||||
|
|
@ -170,7 +174,6 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
|
|||
# to give better errors, and to prevent fd 0 reuse
|
||||
sys.stdin.close()
|
||||
except Exception as e:
|
||||
display.debug(f'Could not detach from stdio: {traceback.format_exc()}')
|
||||
display.error(f'Could not detach from stdio: {e}')
|
||||
os._exit(1)
|
||||
|
||||
|
|
@ -187,6 +190,9 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
|
|||
# Set the queue on Display so calls to Display.display are proxied over the queue
|
||||
display.set_queue(self._final_q)
|
||||
self._detach()
|
||||
# propagate signals
|
||||
signal.signal(signal.SIGINT, self._term)
|
||||
signal.signal(signal.SIGTERM, self._term)
|
||||
try:
|
||||
with _task.TaskContext(self._task):
|
||||
return self._run()
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
|
@ -185,8 +187,48 @@ class TaskQueueManager:
|
|||
# plugins for inter-process locking.
|
||||
self._connection_lockfile = tempfile.TemporaryFile()
|
||||
|
||||
self._workers: list[WorkerProcess | None] = []
|
||||
|
||||
# signal handlers to propagate signals to workers
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
def _initialize_processes(self, num: int) -> None:
|
||||
self._workers: list[WorkerProcess | None] = [None] * num
|
||||
# mutable update to ensure the reference stays the same
|
||||
self._workers[:] = [None] * num
|
||||
|
||||
def _signal_handler(self, signum, frame) -> None:
|
||||
"""
|
||||
terminate all running process groups created as a result of calling
|
||||
setsid from within a WorkerProcess.
|
||||
|
||||
Since the children become process leaders, signals will not
|
||||
automatically propagate to them.
|
||||
"""
|
||||
signal.signal(signum, signal.SIG_DFL)
|
||||
|
||||
for worker in self._workers:
|
||||
if worker is None or not worker.is_alive():
|
||||
continue
|
||||
if worker.pid:
|
||||
try:
|
||||
# notify workers
|
||||
os.kill(worker.pid, signum)
|
||||
except OSError as e:
|
||||
if e.errno != errno.ESRCH:
|
||||
signame = signal.strsignal(signum)
|
||||
display.error(f'Unable to send {signame} to child[{worker.pid}]: {e}')
|
||||
|
||||
if signum == signal.SIGINT:
|
||||
# Defer to CLI handling
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
pid = os.getpid()
|
||||
try:
|
||||
os.kill(pid, signum)
|
||||
except OSError as e:
|
||||
signame = signal.strsignal(signum)
|
||||
display.error(f'Unable to send {signame} to {pid}: {e}')
|
||||
|
||||
def load_callbacks(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2,4 +2,5 @@ needs/ssh
|
|||
shippable/posix/group3
|
||||
needs/target/connection
|
||||
needs/target/setup_test_user
|
||||
needs/target/test_utils
|
||||
setup/always/setup_passlib_controller # required for setup_test_user
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ if command -v sshpass > /dev/null; then
|
|||
# ansible with timeout. If we time out, our custom prompt was successfully
|
||||
# searched for. It's a weird way of doing things, but it does ensure
|
||||
# that the flag gets passed to sshpass.
|
||||
timeout 5 ansible -m ping \
|
||||
../test_utils/scripts/timeout.py 5 -- ansible -m ping \
|
||||
-e ansible_connection=ssh \
|
||||
-e ansible_ssh_password_mechanism=sshpass \
|
||||
-e ansible_sshpass_prompt=notThis: \
|
||||
|
|
|
|||
3
test/integration/targets/signal_propagation/aliases
Normal file
3
test/integration/targets/signal_propagation/aliases
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
shippable/posix/group4
|
||||
context/controller
|
||||
needs/target/test_utils
|
||||
14
test/integration/targets/signal_propagation/inventory
Normal file
14
test/integration/targets/signal_propagation/inventory
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
localhost0
|
||||
localhost1
|
||||
localhost2
|
||||
localhost3
|
||||
localhost4
|
||||
localhost5
|
||||
localhost6
|
||||
localhost7
|
||||
localhost8
|
||||
localhost9
|
||||
|
||||
[all:vars]
|
||||
ansible_connection=local
|
||||
ansible_python_interpreter={{ansible_playbook_python}}
|
||||
21
test/integration/targets/signal_propagation/runme.sh
Executable file
21
test/integration/targets/signal_propagation/runme.sh
Executable file
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -x
|
||||
|
||||
../test_utils/scripts/timeout.py -s SIGINT 3 -- \
|
||||
ansible all -i inventory -m debug -a 'msg={{lookup("pipe", "sleep 33")}}' -f 10
|
||||
if [[ "$?" != "124" ]]; then
|
||||
echo "Process was not terminated due to timeout"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# a short sleep to let processes die
|
||||
sleep 2
|
||||
|
||||
sleeps="$(pgrep -alf 'sleep\ 33')"
|
||||
rc="$?"
|
||||
if [[ "$rc" == "0" ]]; then
|
||||
echo "Found lingering processes:"
|
||||
echo "$sleeps"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -2,21 +2,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def signal_type(v: str) -> signal.Signals:
|
||||
if v.isdecimal():
|
||||
return signal.Signals(int(v))
|
||||
if not v.startswith('SIG'):
|
||||
v = f'SIG{v}'
|
||||
return getattr(signal.Signals, v)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('duration', type=int)
|
||||
parser.add_argument('--signal', '-s', default=signal.SIGTERM, type=signal_type)
|
||||
parser.add_argument('command', nargs='+')
|
||||
args = parser.parse_args()
|
||||
|
||||
p: subprocess.Popen | None = None
|
||||
try:
|
||||
p = subprocess.run(
|
||||
' '.join(args.command),
|
||||
shell=True,
|
||||
timeout=args.duration,
|
||||
check=False,
|
||||
)
|
||||
p = subprocess.Popen(args.command)
|
||||
p.wait(timeout=args.duration)
|
||||
sys.exit(p.returncode)
|
||||
except subprocess.TimeoutExpired:
|
||||
if p and p.poll() is None:
|
||||
p.send_signal(args.signal)
|
||||
p.wait()
|
||||
sys.exit(124)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user