Mirror of https://github.com/zebrajr/faceswap.git (synced 2025-12-06 00:20:09 +01:00)
bugfix: convert - Process final items on truncated batch

parent 8fdb856d05
commit ae7793e876
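The bug this commit fixes: the predictor thread only dispatched a batch to the model once enough faces had accumulated, so when the input queue signalled EOF while a partially filled ("truncated") batch was still pending, those final frames were never predicted or passed on to the patch thread. The sketch below illustrates the general pattern with plain queues; the names (drain_batches, the "EOF" sentinel value) are stand-ins for illustration, not faceswap's actual classes.

import queue
from typing import List, Union

EOF = "EOF"  # sentinel pushed by the producer when the stream ends

def drain_batches(in_queue: queue.Queue, batch_size: int) -> List[List[int]]:
    """Collect items into fixed-size batches, flushing any partial batch at EOF."""
    batches: List[List[int]] = []
    batch: List[int] = []
    while True:
        item: Union[str, int] = in_queue.get()
        if item == EOF:
            if batch:  # the fix: process out any remaining items
                batches.append(batch)
            break
        batch.append(item)
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    return batches

# Seven items with batch_size=3 -> [[0, 1, 2], [3, 4, 5], [6]].
# Without the flush at EOF, the trailing [6] would be silently dropped.
q = queue.Queue()
for i in range(7):
    q.put(i)
q.put(EOF)
print(drain_batches(q, batch_size=3))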
@@ -92,7 +92,7 @@ class PluginLoader():
         return PluginLoader._import("extract.mask", name, disable_logging)
 
     @staticmethod
-    def get_model(name: str, disable_logging: bool = False) -> "ModelBase":
+    def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]:
         """ Return requested training model plugin
 
         Parameters
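The hunk above is an unrelated typing correction: PluginLoader.get_model returns the plugin class itself for the caller to instantiate, not an instance, so the annotation should be Type["ModelBase"] rather than "ModelBase". A self-contained sketch of the distinction, using hypothetical stand-ins rather than faceswap's plugin loader:

from typing import Type

class ModelBase:
    """Stand-in for a training model plugin base class."""

class Original(ModelBase):
    """Stand-in for one concrete plugin."""

def get_model(name: str) -> Type[ModelBase]:
    """Return the plugin class; annotating '-> ModelBase' would wrongly claim an instance."""
    registry = {"original": Original}
    return registry[name]

model_cls = get_model("original")  # a class object
model = model_cls()                # the caller instantiates it
assert isinstance(model, ModelBase)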
@@ -898,11 +898,12 @@ class Predict():
         faces_seen = 0
         consecutive_no_faces = 0
         batch: List[ConvertItem] = []
-        is_amd = get_backend() == "amd"
         while True:
             item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
             if item == "EOF":
                 logger.debug("EOF Received")
+                if batch:  # Process out any remaining items
+                    self._process_batch(batch, faces_seen)
                 break
             logger.trace("Got from queue: '%s'", item.inbound.filename)  # type:ignore
             faces_count = len(item.inbound.detected_faces)
@@ -928,22 +929,7 @@ class Predict():
                              "consecutive_no_faces: %s", faces_seen, consecutive_no_faces)
                 continue
 
-            if batch:
-                logger.trace("Batching to predictor. Frames: %s, Faces: %s",  # type:ignore
-                             len(batch), faces_seen)
-                feed_batch = [feed_face for item in batch
-                              for feed_face in item.feed_faces]
-                if faces_seen != 0:
-                    feed_faces = self._compile_feed_faces(feed_batch)
-                    batch_size = None
-                    if is_amd and feed_faces.shape[0] != self._batchsize:
-                        logger.verbose("Fallback to BS=1")  # type:ignore
-                        batch_size = 1
-                    predicted = self._predict(feed_faces, batch_size)
-                else:
-                    predicted = np.array([])
-
-                self._queue_out_frames(batch, predicted)
+            self._process_batch(batch, faces_seen)
 
             consecutive_no_faces = 0
             faces_seen = 0
@@ -953,6 +939,36 @@ class Predict():
         self._out_queue.put("EOF")
         logger.debug("Load queue complete")
 
+    def _process_batch(self, batch: List[ConvertItem], faces_seen: int):
+        """ Predict faces on the given batch of images and queue out to patch thread
+
+        Parameters
+        ----------
+        batch: list
+            List of :class:`ConvertItem` objects for the current batch
+        faces_seen: int
+            The number of faces seen in the current batch
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The predicted faces for the current batch
+        """
+        logger.trace("Batching to predictor. Frames: %s, Faces: %s",  # type:ignore
+                     len(batch), faces_seen)
+        feed_batch = [feed_face for item in batch for feed_face in item.feed_faces]
+        if faces_seen != 0:
+            feed_faces = self._compile_feed_faces(feed_batch)
+            batch_size = None
+            if get_backend() == "amd" and feed_faces.shape[0] != self._batchsize:
+                logger.verbose("Fallback to BS=1")  # type:ignore
+                batch_size = 1
+            predicted = self._predict(feed_faces, batch_size)
+        else:
+            predicted = np.array([])
+
+        self._queue_out_frames(batch, predicted)
+
     def load_aligned(self, item: ConvertItem) -> None:
         """ Load the model's feed faces and the reference output faces.
 
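The new _process_batch helper keeps the AMD-specific detail from the old inline code: when the backend is "amd" and the number of feed faces does not match self._batchsize, prediction falls back to a batch size of 1. Below is a minimal sketch of that fallback idea, written on the assumption that the backend requires a feed matching a fixed batch size; the names (COMPILED_BATCHSIZE, predict, predict_batch) are hypothetical, not faceswap's API.

import numpy as np

COMPILED_BATCHSIZE = 16  # assumed fixed batch size the backend was set up with

def predict(feed_faces: np.ndarray, batchsize: int) -> np.ndarray:
    """Stand-in predictor that, like a fixed-batch backend, only accepts feeds
    whose leading dimension equals the batch size it was given."""
    assert feed_faces.shape[0] == batchsize, "fixed-batch backend fed a partial batch"
    return feed_faces * 0.5  # dummy "swap"

def predict_batch(feed_faces: np.ndarray) -> np.ndarray:
    """Fall back to batch size 1 when the feed cannot fill the fixed batch,
    mirroring the "Fallback to BS=1" branch in the diff above."""
    batchsize = COMPILED_BATCHSIZE
    if feed_faces.shape[0] != COMPILED_BATCHSIZE:
        batchsize = 1
    return np.concatenate([predict(feed_faces[idx:idx + batchsize], batchsize)
                           for idx in range(0, feed_faces.shape[0], batchsize)])

full = np.random.rand(16, 64, 64, 3).astype("float32")
part = np.random.rand(5, 64, 64, 3).astype("float32")   # a truncated final batch
print(predict_batch(full).shape, predict_batch(part).shape)  # (16, 64, 64, 3) (5, 64, 64, 3)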