mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This at least partly fixes a recurrent problem when using everstore data input (or any other data input with multiprocessing). If the main process dies violently, the child processes are not killed. One cause for this was when using the TimeoutGuard(), as it called os._exit(1), which prevents any cleanup from happening. I changed it to send a SIGINT signal to the PID, and if the process is still alive after 10 secs, to call os._exit(1). In my tests, this works well. Did some other cleanup: - improved logging of inputs/sec in data_workers - removed redundant atexit() handling as the multiprocessing pool does it itself Differential Revision: D4602550 fbshipit-source-id: 64d4526a2a3625d163d23f078286e719d56998f4
68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import contextlib
|
|
import threading
|
|
import os
|
|
import time
|
|
import signal
|
|
import logging
|
|
|
|
|
|
'''
|
|
Sometimes CUDA devices can get stuck, 'deadlock'. In this case it is often
|
|
better just to kill the process automatically. Use this guard to set a
|
|
maximum timespan for a python call, such as RunNet(). If it does not complete
|
|
in time, process is killed.
|
|
|
|
Example usage:
|
|
with timeout_guard.CompleteInTimeOrDie(10.0):
|
|
core.RunNet(...)
|
|
'''
|
|
|
|
|
|
class WatcherThread(threading.Thread):
    """Daemon thread that terminates the process if it is not disarmed in time.

    The owner disarms the watcher by setting ``completed`` to True and
    notifying ``condition``. If ``timeout_secs`` elapse first, the thread
    logs an error, sends SIGINT to the current process and, should the
    process still be alive 10 seconds later, calls ``os._exit(1)``.
    """

    def __init__(self, timeout_secs):
        """Create the watcher.

        Args:
            timeout_secs: number of seconds to wait for ``completed`` to be
                set before the process is terminated.
        """
        threading.Thread.__init__(self)
        self.timeout_secs = timeout_secs
        # Set to True by the owner once the guarded call has finished.
        self.completed = False
        # Used by the owner to wake this thread up before the deadline.
        self.condition = threading.Condition()
        # Daemon thread: it must never keep the interpreter alive by itself.
        self.daemon = True

    def run(self):
        """Wait for completion or the deadline, then kill the process if late."""
        # Prefer a monotonic clock so that wall-clock adjustments cannot fire
        # the watchdog early or delay it; fall back to time.time() on
        # interpreters that lack time.monotonic (Python 2).
        clock = getattr(time, "monotonic", time.time)
        deadline = clock() + self.timeout_secs

        # The with-statement guarantees the lock is released even if wait()
        # raises.
        with self.condition:
            while not self.completed:
                remaining = deadline - clock()
                if remaining <= 0:
                    break
                self.condition.wait(remaining)

        if self.completed:
            return

        log = logging.getLogger("timeout_guard")
        log.error("Call did not finish in time. Timeout:{}s".format(
            self.timeout_secs
        ))

        # First try dying cleanly, but if the process is still around in
        # 10 secs, exit the hard way (os._exit skips all cleanup).
        def forcequit():
            time.sleep(10.0)
            log.error("Process did not terminate cleanly in 10 s, forcing")
            os._exit(1)

        forcet = threading.Thread(target=forcequit, args=())
        forcet.daemon = True
        forcet.start()

        # Ask our own process to shut down via SIGINT so that normal
        # interpreter cleanup gets a chance to run.
        os.kill(os.getpid(), signal.SIGINT)
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def CompleteInTimeOrDie(timeout_secs):
    """Context manager that kills the process if its body runs too long.

    A WatcherThread is started on entry and disarmed on exit. The watcher
    is disarmed whether the body returns normally or raises, so an exception
    that the caller handles does not lead to the process being killed later.

    Args:
        timeout_secs: maximum number of seconds the body may take.
    """
    watcher = WatcherThread(timeout_secs)
    watcher.start()
    try:
        yield
    finally:
        # Mark completion first, then wake the watcher so it sees the flag
        # and exits without signalling the process.
        watcher.completed = True
        with watcher.condition:
            watcher.condition.notify()
|