mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
* tools.alignments - add export job * plugins.extract: Update __repr__ for ExtractorBatch dataclass * plugins.extract: Initial implementation of external import plugins * plugins.extract: Disable lm masks on ROI alignment data import * lib.align: Add `landmark_type` property to AlignedFace and return dummy data for ROI Landmarks pose estimate * plugins.extract: Add centering config item for align import and fix filename mapping for images * plugins.extract: Log warning on downstream plugins on limited alignment data * tools: Fix plugins for 4 point ROI landmarks (alignments, sort, mask) * tools.manual: Fix for 2D-4 ROI landmarks * training: Fix for 4 point ROI landmarks * lib.convert: Average color plugin. Avoid divide by zero errors * extract - external: - Default detector to 'external' when importing alignments - Handle different frame origin co-ordinates * alignments: Store video extension in alignments file * plugins.extract.external: Handle video file keys * plugins.extract.external: Output warning if missing data * locales + docs * plugins.extract.align.external: Roll the corner points to top-left for different origins * Clean up * linting fix
270 lines
10 KiB
Python
270 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
""" Handles the generation of masks from faceswap for updating into an alignments file """
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import typing as T
|
|
|
|
from lib.image import encode_image, ImagesSaver
|
|
from lib.multithreading import MultiThread
|
|
from plugins.extract import Extractor
|
|
|
|
if T.TYPE_CHECKING:
|
|
from lib.align import Alignments, DetectedFace
|
|
from lib.align.alignments import PNGHeaderDict
|
|
from lib.queue_manager import EventQueue
|
|
from plugins.extract import ExtractMedia
|
|
from .loader import Loader
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MaskGenerator:
    """ Uses faceswap's extract pipeline to generate masks and update them into the alignments
    file and/or extracted face PNG Headers

    Parameters
    ----------
    mask_type: str
        The mask type to generate
    update_all: bool
        ``True`` to update all faces, ``False`` to only update faces missing masks
    input_is_faces: bool
        ``True`` if the input are faceswap extracted faces otherwise ``False``
    loader: :class:`tools.mask.loader.Loader`
        The loader for loading source images/video from disk
    alignments: :class:`~lib.align.Alignments` or ``None``
        The alignments file object to update with the generated masks. May be ``None`` when the
        input is faces, in which case only the extracted face PNG headers are updated
    input_location: str
        The location handed to the face saver for writing the updated faces. Only used when
        `input_is_faces` is ``True``
    exclude_gpus: list[int]
        List of any GPU IDs that should be excluded
    """
    def __init__(self,
                 mask_type: str,
                 update_all: bool,
                 input_is_faces: bool,
                 loader: Loader,
                 alignments: Alignments | None,
                 input_location: str,
                 exclude_gpus: list[int]) -> None:
        logger.debug("Initializing %s (mask_type: %s, update_all: %s, input_is_faces: %s, "
                     "loader: %s, alignments: %s, input_location: %s, exclude_gpus: %s)",
                     self.__class__.__name__, mask_type, update_all, input_is_faces, loader,
                     alignments, input_location, exclude_gpus)

        self._update_all = update_all
        self._is_faces = input_is_faces
        self._alignments = alignments

        # The counters must exist before the loader thread is started: the thread target
        # (:func:`_feed_extractor`) increments them as soon as the first item is loaded.
        self._counts: dict[T.Literal["face", "update"], int] = {"face": 0, "update": 0}

        self._extractor = self._get_extractor(mask_type, exclude_gpus)
        # The final mask type depends on the launched plugin, so is resolved after the extractor
        self._mask_type = self._set_correct_mask_type(mask_type)
        self._input_thread = self._set_loader_thread(loader)
        self._saver = ImagesSaver(input_location, as_bytes=True) if input_is_faces else None

        logger.debug("Initialized %s", self.__class__.__name__)

    def _get_extractor(self, mask_type: str, exclude_gpus: list[int]) -> Extractor:
        """ Obtain a Mask extractor plugin and launch it

        Parameters
        ----------
        mask_type: str
            The mask type to generate
        exclude_gpus: list[int]
            A list of GPU indices that should be excluded from use by the extractor

        Returns
        -------
        :class:`~plugins.extract.Extractor`:
            The launched Extractor
        """
        logger.debug("masker: %s", mask_type)
        # No detector or aligner is requested, only the mask plugin
        extractor = Extractor(None, None, mask_type, exclude_gpus=exclude_gpus)
        extractor.launch()
        logger.debug(extractor)
        return extractor

    def _set_correct_mask_type(self, mask_type: str) -> str:
        """ Some masks have multiple variants that they can be saved as depending on config options

        Parameters
        ----------
        mask_type: str
            The mask type to generate

        Returns
        -------
        str
            The actual mask variant to update
        """
        if mask_type != "bisenet-fp":
            return mask_type

        # Hacky look up into masker to get the type of mask
        mask_plugin = self._extractor._mask[0]  # pylint:disable=protected-access
        assert mask_plugin is not None
        mtype = "head" if mask_plugin.config.get("include_hair", False) else "face"
        new_type = f"{mask_type}_{mtype}"
        logger.debug("Updating '%s' to '%s'", mask_type, new_type)
        return new_type

    def _needs_update(self, frame: str, idx: int, face: DetectedFace) -> bool:
        """ Check if the mask for the current alignment needs updating for the requested mask_type

        Parameters
        ----------
        frame: str
            The frame name in the alignments file
        idx: int
            The index of the face for this frame in the alignments file
        face: :class:`~lib.align.DetectedFace`
            The detected face object to check

        Returns
        -------
        bool:
            ``True`` if the mask needs to be updated otherwise ``False``
        """
        if self._update_all:
            return True

        # An update is required when the face holds no masks at all, or none of the requested type
        retval = not face.mask or face.mask.get(self._mask_type) is None

        logger.trace("Needs updating: %s, '%s' - %s",  # type:ignore[attr-defined]
                     retval, frame, idx)
        return retval

    def _feed_extractor(self, loader: Loader, extract_queue: EventQueue) -> None:
        """ Process to feed the extractor from inside a thread

        Items that do not require a mask update are skipped. Once the loader is exhausted an
        ``"EOF"`` item is placed on the queue.

        Parameters
        ----------
        loader: :class:`tools.mask.loader.Loader`
            The loader for loading source images/video from disk
        extract_queue: :class:`lib.queue_manager.EventQueue`
            The input queue to the extraction pipeline
        """
        for media in loader.load():
            self._counts["face"] += len(media.detected_faces)

            if self._is_faces:
                assert len(media.detected_faces) == 1
                needs_update = self._needs_update(media.frame_metadata["source_filename"],
                                                  media.frame_metadata["face_index"],
                                                  media.detected_faces[0])
            else:
                # To keep face indexes correct/cover off where only one face in an image is missing
                # a mask where there are multiple faces we process all faces again for any frames
                # which have missing masks.
                needs_update = any(self._needs_update(media.filename, idx, detected_face)
                                   for idx, detected_face in enumerate(media.detected_faces))

            if not needs_update:
                logger.trace("No masks need updating in '%s'",  # type:ignore[attr-defined]
                             media.filename)
                continue

            logger.trace("Passing to extractor: '%s'", media.filename)  # type:ignore[attr-defined]
            extract_queue.put(media)

        logger.debug("Terminating loader thread")
        extract_queue.put("EOF")

    def _set_loader_thread(self, loader: Loader) -> MultiThread:
        """ Set the iterator to load ExtractMedia objects into the mask extraction pipeline
        so we can just iterate through the output masks

        Parameters
        ----------
        loader: :class:`tools.mask.loader.Loader`
            The loader for loading source images/video from disk

        Returns
        -------
        :class:`lib.multithreading.MultiThread`
            The started thread that is feeding the extractor's input queue
        """
        in_queue = self._extractor.input_queue
        logger.debug("Starting load thread: (loader: %s, queue: %s)", loader, in_queue)
        in_thread = MultiThread(self._feed_extractor, loader, in_queue, thread_count=1)
        in_thread.start()
        logger.debug("Started load thread: %s", in_thread)
        return in_thread

    def _update_from_face(self, media: ExtractMedia) -> None:
        """ Update the alignments file and/or the extracted face

        Parameters
        ----------
        media: :class:`~plugins.extract.ExtractMedia`
            The ExtractMedia object with updated masks
        """
        assert len(media.detected_faces) == 1
        assert self._saver is not None

        fname = media.frame_metadata["source_filename"]
        idx = media.frame_metadata["face_index"]
        face = media.detected_faces[0]

        # The alignments file is optional when processing faces
        if self._alignments is not None:
            logger.trace("Updating face %s in frame '%s'", idx, fname)  # type:ignore[attr-defined]
            self._alignments.update_face(fname, idx, face.to_alignment())

        # The face image is always re-saved with the updated alignment data in its header
        logger.trace("Updating extracted face: '%s'", media.filename)  # type:ignore[attr-defined]
        meta: PNGHeaderDict = {"alignments": face.to_png_meta(), "source": media.frame_metadata}
        self._saver.save(media.filename, encode_image(media.image, ".png", metadata=meta))

    def _update_from_frame(self, media: ExtractMedia) -> None:
        """ Update the alignments file

        Parameters
        ----------
        media: :class:`~plugins.extract.ExtractMedia`
            The ExtractMedia object with updated masks
        """
        assert self._alignments is not None
        # The alignments file is updated using the base name of the frame, not its full path
        fname = os.path.basename(media.filename)
        logger.trace("Updating %s faces in frame '%s'",  # type:ignore[attr-defined]
                     len(media.detected_faces), fname)
        for idx, face in enumerate(media.detected_faces):
            self._alignments.update_face(fname, idx, face.to_alignment())

    def _finalize(self) -> None:
        """ Close thread and save alignments on completion """
        logger.debug("Finalizing MaskGenerator")
        self._input_thread.join()

        # Only back up and re-save the alignments file if something was actually changed
        if self._counts["update"] > 0 and self._alignments is not None:
            logger.debug("Saving alignments")
            self._alignments.backup()
            self._alignments.save()

        if self._saver is not None:
            logger.debug("Closing face saver")
            self._saver.close()

        if self._counts["update"] == 0:
            logger.warning("No masks were updated of the %s faces seen", self._counts["face"])
        else:
            logger.info("Updated masks for %s faces of %s",
                        self._counts["update"], self._counts["face"])

    def process(self) -> T.Generator[ExtractMedia, None, None]:
        """ Process the output from the extractor pipeline

        The alignments file and/or face images are updated for each item as it is received.
        Finalization (joining the loader thread, saving alignments, closing the saver) is only
        performed once the generator has been fully consumed.

        Yields
        ------
        :class:`~plugins.extract.ExtractMedia`
            The ExtractMedia object with updated masks
        """
        for media in self._extractor.detected_faces():
            # Surface any error that occurred inside the loader thread
            self._input_thread.check_and_raise_error()
            self._counts["update"] += len(media.detected_faces)

            if self._is_faces:
                self._update_from_face(media)
            else:
                self._update_from_frame(media)

            yield media

        self._finalize()
        logger.debug("Completed MaskGenerator process")
|