mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 12:20:27 +01:00
bugfix: convert - Process final items on truncated batch
This commit is contained in:
parent
8fdb856d05
commit
ae7793e876
|
|
@ -92,7 +92,7 @@ class PluginLoader():
|
||||||
return PluginLoader._import("extract.mask", name, disable_logging)
|
return PluginLoader._import("extract.mask", name, disable_logging)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model(name: str, disable_logging: bool = False) -> "ModelBase":
|
def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]:
|
||||||
""" Return requested training model plugin
|
""" Return requested training model plugin
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|
|
||||||
|
|
@ -898,11 +898,12 @@ class Predict():
|
||||||
faces_seen = 0
|
faces_seen = 0
|
||||||
consecutive_no_faces = 0
|
consecutive_no_faces = 0
|
||||||
batch: List[ConvertItem] = []
|
batch: List[ConvertItem] = []
|
||||||
is_amd = get_backend() == "amd"
|
|
||||||
while True:
|
while True:
|
||||||
item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
|
item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
|
||||||
if item == "EOF":
|
if item == "EOF":
|
||||||
logger.debug("EOF Received")
|
logger.debug("EOF Received")
|
||||||
|
if batch: # Process out any remaining items
|
||||||
|
self._process_batch(batch, faces_seen)
|
||||||
break
|
break
|
||||||
logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore
|
logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore
|
||||||
faces_count = len(item.inbound.detected_faces)
|
faces_count = len(item.inbound.detected_faces)
|
||||||
|
|
@ -928,22 +929,7 @@ class Predict():
|
||||||
"consecutive_no_faces: %s", faces_seen, consecutive_no_faces)
|
"consecutive_no_faces: %s", faces_seen, consecutive_no_faces)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if batch:
|
self._process_batch(batch, faces_seen)
|
||||||
logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore
|
|
||||||
len(batch), faces_seen)
|
|
||||||
feed_batch = [feed_face for item in batch
|
|
||||||
for feed_face in item.feed_faces]
|
|
||||||
if faces_seen != 0:
|
|
||||||
feed_faces = self._compile_feed_faces(feed_batch)
|
|
||||||
batch_size = None
|
|
||||||
if is_amd and feed_faces.shape[0] != self._batchsize:
|
|
||||||
logger.verbose("Fallback to BS=1") # type:ignore
|
|
||||||
batch_size = 1
|
|
||||||
predicted = self._predict(feed_faces, batch_size)
|
|
||||||
else:
|
|
||||||
predicted = np.array([])
|
|
||||||
|
|
||||||
self._queue_out_frames(batch, predicted)
|
|
||||||
|
|
||||||
consecutive_no_faces = 0
|
consecutive_no_faces = 0
|
||||||
faces_seen = 0
|
faces_seen = 0
|
||||||
|
|
@ -953,6 +939,36 @@ class Predict():
|
||||||
self._out_queue.put("EOF")
|
self._out_queue.put("EOF")
|
||||||
logger.debug("Load queue complete")
|
logger.debug("Load queue complete")
|
||||||
|
|
||||||
|
def _process_batch(self, batch: List[ConvertItem], faces_seen: int):
|
||||||
|
""" Predict faces on the given batch of images and queue out to patch thread
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch: list
|
||||||
|
List of :class:`ConvertItem` objects for the current batch
|
||||||
|
faces_seen: int
|
||||||
|
The number of faces seen in the current batch
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
:class:`np.narray`
|
||||||
|
The predicted faces for the current batch
|
||||||
|
"""
|
||||||
|
logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore
|
||||||
|
len(batch), faces_seen)
|
||||||
|
feed_batch = [feed_face for item in batch for feed_face in item.feed_faces]
|
||||||
|
if faces_seen != 0:
|
||||||
|
feed_faces = self._compile_feed_faces(feed_batch)
|
||||||
|
batch_size = None
|
||||||
|
if get_backend() == "amd" and feed_faces.shape[0] != self._batchsize:
|
||||||
|
logger.verbose("Fallback to BS=1") # type:ignore
|
||||||
|
batch_size = 1
|
||||||
|
predicted = self._predict(feed_faces, batch_size)
|
||||||
|
else:
|
||||||
|
predicted = np.array([])
|
||||||
|
|
||||||
|
self._queue_out_frames(batch, predicted)
|
||||||
|
|
||||||
def load_aligned(self, item: ConvertItem) -> None:
|
def load_aligned(self, item: ConvertItem) -> None:
|
||||||
""" Load the model's feed faces and the reference output faces.
|
""" Load the model's feed faces and the reference output faces.
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user