bugfix: convert - Process final items on truncated batch

This commit is contained in:
torzdf 2022-09-02 12:42:52 +01:00
parent 8fdb856d05
commit ae7793e876
2 changed files with 34 additions and 18 deletions

View File

@ -92,7 +92,7 @@ class PluginLoader():
return PluginLoader._import("extract.mask", name, disable_logging) return PluginLoader._import("extract.mask", name, disable_logging)
@staticmethod @staticmethod
def get_model(name: str, disable_logging: bool = False) -> "ModelBase": def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]:
""" Return requested training model plugin """ Return requested training model plugin
Parameters Parameters

View File

@ -898,11 +898,12 @@ class Predict():
faces_seen = 0 faces_seen = 0
consecutive_no_faces = 0 consecutive_no_faces = 0
batch: List[ConvertItem] = [] batch: List[ConvertItem] = []
is_amd = get_backend() == "amd"
while True: while True:
item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get() item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
if item == "EOF": if item == "EOF":
logger.debug("EOF Received") logger.debug("EOF Received")
if batch: # Process out any remaining items
self._process_batch(batch, faces_seen)
break break
logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore
faces_count = len(item.inbound.detected_faces) faces_count = len(item.inbound.detected_faces)
@ -928,22 +929,7 @@ class Predict():
"consecutive_no_faces: %s", faces_seen, consecutive_no_faces) "consecutive_no_faces: %s", faces_seen, consecutive_no_faces)
continue continue
if batch: self._process_batch(batch, faces_seen)
logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore
len(batch), faces_seen)
feed_batch = [feed_face for item in batch
for feed_face in item.feed_faces]
if faces_seen != 0:
feed_faces = self._compile_feed_faces(feed_batch)
batch_size = None
if is_amd and feed_faces.shape[0] != self._batchsize:
logger.verbose("Fallback to BS=1") # type:ignore
batch_size = 1
predicted = self._predict(feed_faces, batch_size)
else:
predicted = np.array([])
self._queue_out_frames(batch, predicted)
consecutive_no_faces = 0 consecutive_no_faces = 0
faces_seen = 0 faces_seen = 0
@ -953,6 +939,36 @@ class Predict():
self._out_queue.put("EOF") self._out_queue.put("EOF")
logger.debug("Load queue complete") logger.debug("Load queue complete")
def _process_batch(self, batch: List[ConvertItem], faces_seen: int):
""" Predict faces on the given batch of images and queue out to patch thread
Parameters
----------
batch: list
List of :class:`ConvertItem` objects for the current batch
faces_seen: int
The number of faces seen in the current batch
Returns
-------
:class:`np.narray`
The predicted faces for the current batch
"""
logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore
len(batch), faces_seen)
feed_batch = [feed_face for item in batch for feed_face in item.feed_faces]
if faces_seen != 0:
feed_faces = self._compile_feed_faces(feed_batch)
batch_size = None
if get_backend() == "amd" and feed_faces.shape[0] != self._batchsize:
logger.verbose("Fallback to BS=1") # type:ignore
batch_size = 1
predicted = self._predict(feed_faces, batch_size)
else:
predicted = np.array([])
self._queue_out_frames(batch, predicted)
def load_aligned(self, item: ConvertItem) -> None: def load_aligned(self, item: ConvertItem) -> None:
""" Load the model's feed faces and the reference output faces. """ Load the model's feed faces and the reference output faces.