bugfix - timelapse image loader

multithreading.py - typing + docs
This commit is contained in:
torzdf 2022-08-25 12:38:05 +01:00
parent 9e503bdaa2
commit 326110f09d
4 changed files with 188 additions and 50 deletions

View File

@ -0,0 +1,7 @@
multithreading module
=====================
.. automodule:: lib.multithreading
:members:
:undoc-members:
:show-inheritance:

View File

@ -7,8 +7,13 @@ from multiprocessing import cpu_count
import queue as Queue import queue as Queue
import sys import sys
import threading import threading
from types import TracebackType
from typing import Any, Callable, Dict, Generator, List, Tuple, Type, Optional, Set, Union
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_ErrorType = Optional[Union[Tuple[Type[BaseException], BaseException, TracebackType],
Tuple[Any, Any, Any]]]
_THREAD_NAMES: Set[str] = set()
def total_cpus(): def total_cpus():
@ -16,22 +21,76 @@ def total_cpus():
return cpu_count() return cpu_count()
class FSThread(threading.Thread): def _get_name(name: str) -> str:
""" Subclass of thread that passes errors back to parent """ """ Obtain a unique name for a thread
def __init__(self, group=None, target=None, name=None, # pylint: disable=too-many-arguments
args=(), kwargs=None, *, daemon=None):
super().__init__(group=group, target=target, name=name,
args=args, kwargs=kwargs, daemon=daemon)
self.err = None
def check_and_raise_error(self): Parameters
""" Checks for errors in thread and raises them in caller """ ----------
name: str
The requested name
Returns
-------
str
The requested name with "_#" appended (# being an integer) making the name unique
"""
idx = 0
real_name = name
while True:
if real_name in _THREAD_NAMES:
real_name = f"{name}_{idx}"
idx += 1
continue
_THREAD_NAMES.add(real_name)
return real_name
class FSThread(threading.Thread):
""" Subclass of thread that passes errors back to parent
Parameters
----------
target: callable object, Optional
The callable object to be invoked by the run() method. If ``None`` nothing is called.
Default: ``None``
name: str, optional
The thread name. If ``None`` a unique name is constructed of the form "Thread-N" where N
is a small decimal number. Default: ``None``
args: tuple
The argument tuple for the target invocation. Default: ().
kwargs: dict
keyword arguments for the target invocation. Default: {}.
"""
_target: Callable
_args: Tuple
_kwargs: Dict[str, Any]
_name: str
def __init__(self,
target: Optional[Callable] = None,
name: Optional[str] = None,
args: Tuple = (),
kwargs: Dict[str, Any] = None,
*,
daemon: Optional[bool] = None) -> None:
super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
self.err: _ErrorType = None
def check_and_raise_error(self) -> None:
""" Checks for errors in thread and raises them in caller.
Raises
------
Error
Re-raised error from within the thread
"""
if not self.err: if not self.err:
return return
logger.debug("Thread error caught: %s", self.err) logger.debug("Thread error caught: %s", self.err)
raise self.err[1].with_traceback(self.err[2]) raise self.err[1].with_traceback(self.err[2])
def run(self): def run(self) -> None:
""" Runs the target, reraising any errors from within the thread in the caller. """
try: try:
if self._target: if self._target:
self._target(*self._args, **self._kwargs) self._target(*self._args, **self._kwargs)
@ -45,53 +104,85 @@ class FSThread(threading.Thread):
class MultiThread(): class MultiThread():
""" Threading for IO heavy ops """ Threading for IO heavy ops. Catches errors in thread and rethrows to parent.
Catches errors in thread and rethrows to parent """
def __init__(self, target, *args, thread_count=1, name=None, **kwargs): Parameters
self._name = name if name else target.__name__ ----------
target: callable object
The callable object to be invoked by the run() method.
args: tuple
The argument tuple for the target invocation. Default: ().
thread_count: int, optional
The number of threads to use. Default: 1
name: str, optional
The thread name. If ``None`` a unique name is constructed of the form {target.__name__}_N
where N is an incrementing integer. Default: ``None``
kwargs: dict
keyword arguments for the target invocation. Default: {}.
"""
def __init__(self,
target: Callable,
*args,
thread_count: int = 1,
name: Optional[str] = None,
**kwargs) -> None:
self._name = _get_name(name if name else target.__name__)
logger.debug("Initializing %s: (target: '%s', thread_count: %s)", logger.debug("Initializing %s: (target: '%s', thread_count: %s)",
self.__class__.__name__, self._name, thread_count) self.__class__.__name__, self._name, thread_count)
logger.trace("args: %s, kwargs: %s", args, kwargs) logger.trace("args: %s, kwargs: %s", args, kwargs) # type:ignore
self.daemon = True self.daemon = True
self._thread_count = thread_count self._thread_count = thread_count
self._threads = list() self._threads: List[FSThread] = []
self._target = target self._target = target
self._args = args self._args = args
self._kwargs = kwargs self._kwargs = kwargs
logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name) logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name)
@property @property
def has_error(self): def has_error(self) -> bool:
""" Return true if a thread has errored, otherwise false """ """ bool: ``True`` if a thread has errored, otherwise ``False`` """
return any(thread.err for thread in self._threads) return any(thread.err for thread in self._threads)
@property @property
def errors(self): def errors(self) -> List[_ErrorType]:
""" Return a list of thread errors """ """ list: List of thread error values """
return [thread.err for thread in self._threads if thread.err] return [thread.err for thread in self._threads if thread.err]
@property @property
def name(self): def name(self) -> str:
""" Return thread name """ """ :str: The name of the thread """
return self._name return self._name
def check_and_raise_error(self): def check_and_raise_error(self) -> None:
""" Checks for errors in thread and raises them in caller """ """ Checks for errors in thread and raises them in caller.
Raises
------
Error
Re-raised error from within the thread
"""
if not self.has_error: if not self.has_error:
return return
logger.debug("Thread error caught: %s", self.errors) logger.debug("Thread error caught: %s", self.errors)
error = self.errors[0] error = self.errors[0]
assert error is not None
raise error[1].with_traceback(error[2]) raise error[1].with_traceback(error[2])
def is_alive(self): def is_alive(self) -> bool:
""" Return true if any thread is alive else false """ """ Check if any threads are still alive
Returns
-------
bool
``True`` if any threads are alive. ``False`` if no threads are alive
"""
return any(thread.is_alive() for thread in self._threads) return any(thread.is_alive() for thread in self._threads)
def start(self): def start(self) -> None:
""" Start a thread with the given method and args """ """ Start all the threads for the given method, args and kwargs """
logger.debug("Starting thread(s): '%s'", self._name) logger.debug("Starting thread(s): '%s'", self._name)
for idx in range(self._thread_count): for idx in range(self._thread_count):
name = "{}_{}".format(self._name, idx) name = self._name if self._thread_count == 1 else f"{self._name}_{idx}"
logger.debug("Starting thread %s of %s: '%s'", logger.debug("Starting thread %s of %s: '%s'",
idx + 1, self._thread_count, name) idx + 1, self._thread_count, name)
thread = FSThread(name=name, thread = FSThread(name=name,
@ -103,13 +194,18 @@ class MultiThread():
self._threads.append(thread) self._threads.append(thread)
logger.debug("Started all threads '%s': %s", self._name, len(self._threads)) logger.debug("Started all threads '%s': %s", self._name, len(self._threads))
def completed(self): def completed(self) -> bool:
""" Return False if there are any alive threads else True """ """ Check if all threads have completed
Returns
-------
``True`` if all threads have completed otherwise ``False``
"""
retval = all(not thread.is_alive() for thread in self._threads) retval = all(not thread.is_alive() for thread in self._threads)
logger.debug(retval) logger.debug(retval)
return retval return retval
def join(self): def join(self) -> None:
""" Join the running threads, catching and re-raising any errors """ """ Join the running threads, catching and re-raising any errors """
logger.debug("Joining Threads: '%s'", self._name) logger.debug("Joining Threads: '%s'", self._name)
for thread in self._threads: for thread in self._threads:
@ -123,24 +219,53 @@ class MultiThread():
class BackgroundGenerator(MultiThread): class BackgroundGenerator(MultiThread):
""" Run a queue in the background. From: """ Run a task in the background background and queue data for consumption
https://stackoverflow.com/questions/7323664/ """
# See below why prefetch count is flawed Parameters
def __init__(self, generator, prefetch=1, thread_count=2, ----------
queue=None, args=None, kwargs=None): generator: iterable
# pylint:disable=too-many-arguments The generator to run in the background
super().__init__(target=self._run, thread_count=thread_count) prefetch: int, optional
self.queue = queue or Queue.Queue(prefetch) The number of items to pre-fetch from the generator before blocking (see Notes). Default: 1
name: str, optional
The thread name. if ``None`` a unique name is constructed of the form
{generator.__name__}_N where N is an incrementing integer. Default: ``None``
args: tuple, Optional
The argument tuple for generator invocation. Default: ``None``.
kwargs: dict, Optional
keyword arguments for the generator invocation. Default: ``None``.
Notes
-----
Putting to the internal queue only blocks if put is called while queue has already
reached max size. Therefore this means prefetch is actually 1 more than the parameter
supplied (N in the queue, one waiting for insertion)
References
----------
https://stackoverflow.com/questions/7323664/
"""
def __init__(self,
generator: Callable,
prefetch: int = 1,
name: Optional[str] = None,
args: Optional[Tuple] = None,
kwargs: Optional[Dict[str, Any]] = None) -> None:
super().__init__(name=name, target=self._run)
self.queue: Queue.Queue = Queue.Queue(prefetch)
self.generator = generator self.generator = generator
self._gen_args = args or tuple() self._gen_args = args or tuple()
self._gen_kwargs = kwargs or dict() self._gen_kwargs = kwargs or {}
self.start() self.start()
def _run(self): def _run(self) -> None:
""" Put until queue size is reached. """ Run the :attr:`_generator` and put into the queue until until queue size is reached.
Note: put blocks only if put is called while queue has already
reached max size => this makes prefetch + thread_count prefetched items! Raises
N in the the queue, one waiting for insertion per thread! """ ------
Exception
If there is a failure to run the generator and put to the queue
"""
try: try:
for item in self.generator(*self._gen_args, **self._gen_kwargs): for item in self.generator(*self._gen_args, **self._gen_kwargs):
self.queue.put(item) self.queue.put(item)
@ -149,8 +274,14 @@ class BackgroundGenerator(MultiThread):
self.queue.put(None) self.queue.put(None)
raise raise
def iterator(self): def iterator(self) -> Generator:
""" Iterate items out of the queue """ """ Iterate items out of the queue
Yields
------
Any
The items from the generator
"""
while True: while True:
next_item = self.queue.get() next_item = self.queue.get()
self.check_and_raise_error() self.check_and_raise_error()

View File

@ -132,7 +132,7 @@ class DataGenerator():
""" """
logger.debug("do_shuffle: %s", do_shuffle) logger.debug("do_shuffle: %s", do_shuffle)
args = (do_shuffle, ) args = (do_shuffle, )
batcher = BackgroundGenerator(self._minibatch, thread_count=1, args=args) batcher = BackgroundGenerator(self._minibatch, args=args)
return batcher.iterator() return batcher.iterator()
# << INTERNAL METHODS >> # # << INTERNAL METHODS >> #

View File

@ -582,7 +582,7 @@ class _Feeder():
iterator[side] = self._load_generator(side, iterator[side] = self._load_generator(side,
True, True,
batch_size=batch_size, batch_size=batch_size,
images=imgs).minibatch_ab() images=imgs).minibatch_ab(do_shuffle=False)
logger.debug("Set time-lapse feed: %s", self._display_feeds["timelapse"]) logger.debug("Set time-lapse feed: %s", self._display_feeds["timelapse"])