From ae7793e87667a94ee64fac580235cd27f34b38b7 Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Fri, 2 Sep 2022 12:42:52 +0100 Subject: [PATCH] bugfix: convert - Process final items on truncated batch --- plugins/plugin_loader.py | 2 +- scripts/convert.py | 50 ++++++++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/plugins/plugin_loader.py b/plugins/plugin_loader.py index 44e55b6..bc3e09c 100644 --- a/plugins/plugin_loader.py +++ b/plugins/plugin_loader.py @@ -92,7 +92,7 @@ class PluginLoader(): return PluginLoader._import("extract.mask", name, disable_logging) @staticmethod - def get_model(name: str, disable_logging: bool = False) -> "ModelBase": + def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]: """ Return requested training model plugin Parameters diff --git a/scripts/convert.py b/scripts/convert.py index 0e88303..4956172 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -898,11 +898,12 @@ class Predict(): faces_seen = 0 consecutive_no_faces = 0 batch: List[ConvertItem] = [] - is_amd = get_backend() == "amd" while True: item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get() if item == "EOF": logger.debug("EOF Received") + if batch: # Process out any remaining items + self._process_batch(batch, faces_seen) break logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore faces_count = len(item.inbound.detected_faces) @@ -928,22 +929,7 @@ class Predict(): "consecutive_no_faces: %s", faces_seen, consecutive_no_faces) continue - if batch: - logger.trace("Batching to predictor. 
Frames: %s, Faces: %s", # type:ignore - len(batch), faces_seen) - feed_batch = [feed_face for item in batch - for feed_face in item.feed_faces] - if faces_seen != 0: - feed_faces = self._compile_feed_faces(feed_batch) - batch_size = None - if is_amd and feed_faces.shape[0] != self._batchsize: - logger.verbose("Fallback to BS=1") # type:ignore - batch_size = 1 - predicted = self._predict(feed_faces, batch_size) - else: - predicted = np.array([]) - - self._queue_out_frames(batch, predicted) + self._process_batch(batch, faces_seen) consecutive_no_faces = 0 faces_seen = 0 @@ -953,6 +939,36 @@ class Predict(): self._out_queue.put("EOF") logger.debug("Load queue complete") + def _process_batch(self, batch: List[ConvertItem], faces_seen: int): + """ Predict faces on the given batch of images and queue out to patch thread + + Parameters + ---------- + batch: list + List of :class:`ConvertItem` objects for the current batch + faces_seen: int + The number of faces seen in the current batch + + Returns + ------- + :class:`np.ndarray` + The predicted faces for the current batch + """ + logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore + len(batch), faces_seen) + feed_batch = [feed_face for item in batch for feed_face in item.feed_faces] + if faces_seen != 0: + feed_faces = self._compile_feed_faces(feed_batch) + batch_size = None + if get_backend() == "amd" and feed_faces.shape[0] != self._batchsize: + logger.verbose("Fallback to BS=1") # type:ignore + batch_size = 1 + predicted = self._predict(feed_faces, batch_size) + else: + predicted = np.array([]) + + self._queue_out_frames(batch, predicted) + def load_aligned(self, item: ConvertItem) -> None: """ Load the model's feed faces and the reference output faces.