faceswap/plugins/extract/align/external.py
torzdf 1c081aea7d
Add ability to export and import alignment data (#1383)
* tools.alignments - add export job

* plugins.extract: Update __repr__ for ExtractorBatch dataclass

* plugins.extract: Initial implementation of external import plugins

* plugins.extract: Disable lm masks on ROI alignment data import

* lib.align: Add `landmark_type` property to AlignedFace and return dummy data for ROI Landmarks pose estimate

* plugins.extract: Add centering config item for align import and fix filename mapping for images

* plugins.extract: Log warning on downstream plugins on limited alignment data

* tools: Fix plugins for 4 point ROI landmarks (alignments, sort, mask)

* tools.manual: Fix for 2D-4 ROI landmarks

* training: Fix for 4 point ROI landmarks

* lib.convert: Average color plugin. Avoid divide by zero errors

* extract - external:
  - Default detector to 'external' when importing alignments
  - Handle different frame origin co-ordinates

* alignments: Store video extension in alignments file

* plugins.extract.external: Handle video file keys

* plugins.extract.external: Output warning if missing data

* locales + docs

* plugins.extract.align.external: Roll the corner points to top-left for different origins

* Clean up

* linting fix
2024-04-15 12:19:15 +01:00

278 lines
11 KiB
Python

#!/usr/bin/env python3
""" Import 68 point landmarks or ROI boxes from a json file """
import logging
import typing as T
import os
import re
import numpy as np
from lib.align import EXTRACT_RATIOS, LandmarkType
from lib.utils import FaceswapError, IMAGE_EXTENSIONS
from ._base import BatchType, Aligner, AlignerBatch
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
class Align(Aligner):
    """ Import 68 point facial landmarks or 4 point ROI boxes from an external json file

    No model is run by this plugin. Landmark data is loaded up front with :func:`import_data`
    and then paired to each incoming frame by filename (images) or frame number (video).
    """

    def __init__(self, **kwargs) -> None:
        kwargs["normalize_method"] = None  # Disable normalization
        kwargs["re_feed"] = 0  # Disable re-feed
        kwargs["re_align"] = False  # Disable re-align
        kwargs["disable_filter"] = True  # Disable aligner filters
        super().__init__(git_model_id=None, model_filename=None, **kwargs)

        self.name = "External"
        self.batchsize = 16

        self._origin: T.Literal["top-left",
                                "bottom-left",
                                "top-right",
                                "bottom-right"] = self.config["origin"]

        self._re_frame_no: re.Pattern = re.compile(r"\d+$")
        self._is_video: bool = False
        self._imported: dict[str | int, tuple[int, np.ndarray]] = {}
        """dict[str | int, tuple[int, np.ndarray]]: filename as key, value of [number of faces
        remaining for the frame, all landmarks in the frame] """
        self._missing: list[str | int] = []
        """list[str | int]: Keys (filenames or video frame numbers) of input frames that could
        not be found in the import data """

        self._roll: dict[T.Literal["bottom-left", "top-right", "bottom-right"], int] = {
            "bottom-left": 3, "top-right": 1, "bottom-right": 2}
        """dict[Literal["bottom-left", "top-right", "bottom-right"], int]: Amount to roll the
        points by for different origins when 4 Point ROI landmarks are provided """

        centering = self.config["4_point_centering"]
        self._adjustment: float = 1. if centering is None else 1. - EXTRACT_RATIOS[centering]
        """float: The amount to adjust 4 point ROI landmarks to standardize the points for a
        'head' sized extracted face """

    def init_model(self) -> None:
        """ No initialization to perform """
        logger.debug("No aligner model to initialize")

    def _check_for_video(self, filename: str) -> None:
        """ Check a sample filename from the import file for a file extension to set
        :attr:`_is_video`

        Parameters
        ----------
        filename: str
            A sample file name from the imported data
        """
        logger.debug("Checking for video from '%s'", filename)
        ext = os.path.splitext(filename)[-1]
        # Anything that is not a known image extension is treated as a video frame key
        if ext.lower() not in IMAGE_EXTENSIONS:
            self._is_video = True
        logger.debug("Set is_video to %s from extension '%s'", self._is_video, ext)

    def _get_key(self, key: str) -> str | int:
        """ Obtain the key for the item in the lookup table. If the input are images, the key will
        be the image filename. If the input is a video, the key will be the frame number

        Parameters
        ----------
        key: str
            The initial key value from import data or an import image/frame

        Returns
        -------
        str | int
            The filename if the input data is images, otherwise the frame number of a video

        Raises
        ------
        FaceswapError
            If the input is a video and the key (less any extension) does not end with a frame
            number
        """
        if not self._is_video:
            return key
        original_name = os.path.splitext(key)[0]
        matches = self._re_frame_no.findall(original_name)
        if not matches or len(matches) > 1:
            raise FaceswapError(f"Invalid import name: '{key}'. For video files, the key should "
                                "end with the frame number.")
        retval = int(matches[0])
        logger.trace("Obtained frame number %s from key '%s'",  # type:ignore[attr-defined]
                     retval, key)
        return retval

    def _import_face(self, face: dict[str, list[int] | list[list[float]]]) -> np.ndarray:
        """ Import the landmarks from a single face

        Parameters
        ----------
        face: dict[str, list[int] | list[list[float]]]
            An import dictionary item for a face

        Returns
        -------
        :class:`numpy.ndarray`
            The landmark data imported from the json file

        Raises
        ------
        FaceswapError
            If the landmarks_2d key does not exist or the landmarks are in an incorrect format
        """
        landmarks = face.get("landmarks_2d")
        if landmarks is None:
            raise FaceswapError("The provided import file is missing the required key "
                                "'landmarks_2d'")
        if len(landmarks) not in (4, 68):
            raise FaceswapError("Imported 'landmarks_2d' should be either 68 facial feature "
                                "landmarks or 4 ROI corner locations")

        format_msg = "Imported 'landmarks_2d' should be formatted as a list of (x, y) co-ordinates"
        try:
            retval = np.array(landmarks, dtype="float32")
        except (TypeError, ValueError) as err:
            # Ragged or non-numeric entries cannot be converted. Report as a user-facing error
            raise FaceswapError(format_msg) from err
        if retval.ndim != 2 or retval.shape[-1] != 2:
            raise FaceswapError(format_msg)

        if retval.shape[0] == 4:  # Adjust ROI landmarks based on centering selected
            center = np.mean(retval, axis=0)
            retval = (retval - center) * self._adjustment + center

        return retval

    def import_data(self, data: dict[str, list[dict[str, list[int] | list[list[float]]]]]) -> None:
        """ Import the aligner data from the json import file and set to :attr:`_imported`

        Parameters
        ----------
        data: dict[str, list[dict[str, list[int] | list[list[float]]]]]
            The data to be imported

        Raises
        ------
        FaceswapError
            If the import data is empty, a face entry is incorrectly formatted, a video key
            does not end in a frame number, or the imported faces do not all have the same
            number of landmarks
        """
        logger.debug("Data length: %s", len(data))
        if not data:
            raise FaceswapError("The provided import file does not contain any data")

        # The first key is used as a sample to decide between image and video input
        self._check_for_video(next(iter(data)))
        for key, faces in data.items():
            try:
                face_lms = [self._import_face(face) for face in faces]
                # Mixed 4 point and 68 point faces cannot be stacked into a single array
                if len({lms.shape for lms in face_lms}) > 1:
                    raise FaceswapError("All faces within a frame should have the same number "
                                        "of landmarks")
                lms = np.array(face_lms, dtype="float32")
                if not np.any(lms):
                    logger.trace("Skipping frame '%s' with no faces",  # type:ignore[attr-defined]
                                 key)
                    continue

                store_key = self._get_key(key)
                self._imported[store_key] = (lms.shape[0], lms)
            except FaceswapError as err:
                logger.error(str(err))
                msg = f"The imported frame key that failed was '{key}'"
                raise FaceswapError(msg) from err

        lm_shape = set(v[1].shape[1:] for v in self._imported.values() if v[0] > 0)
        if len(lm_shape) > 1:
            raise FaceswapError("All external data should have the same number of landmarks. "
                                f"Found landmarks of shape: {lm_shape}")
        if (4, 2) in lm_shape:
            self.landmark_type = LandmarkType.LM_2D_4

    def process_input(self, batch: BatchType) -> None:
        """ Put the filenames and original frame dimensions into `batch.feed` so they can be
        collected for mapping in `.predict`

        Parameters
        ----------
        batch: :class:`~plugins.extract.align._base.AlignerBatch`
            The batch to be processed by the plugin
        """
        batch.feed = np.array([(self._get_key(os.path.basename(f)), i.shape[:2])
                               for f, i in zip(batch.filename, batch.image)], dtype="object")

    def faces_to_feed(self, faces: np.ndarray) -> np.ndarray:
        """ No action required for import plugin

        Parameters
        ----------
        faces: :class:`numpy.ndarray`
            The batch of faces in UINT8 format

        Returns
        -------
        class: `numpy.ndarray`
            the original batch of faces
        """
        return faces

    def _adjust_for_origin(self, landmarks: np.ndarray, frame_dims: tuple[int, int]) -> np.ndarray:
        """ Adjust the landmarks to be top-left orientated based on the selected import origin

        Parameters
        ----------
        landmarks: :class:`np.ndarray`
            The imported facial landmarks box at original (0, 0) origin
        frame_dims: tuple[int, int]
            The (rows, columns) dimensions of the original frame

        Returns
        -------
        :class:`numpy.ndarray`
            The adjusted landmarks box for a top-left origin
        """
        if not np.any(landmarks) or self._origin == "top-left":
            return landmarks

        if LandmarkType.from_shape(landmarks.shape) == LandmarkType.LM_2D_4:
            # Roll the corner points so the first point is the top-left one. np.roll returns a
            # new array, so the imported data is left untouched
            landmarks = np.roll(landmarks, self._roll[self._origin], axis=0)
        else:
            # Work on a copy so the in-place flips below do not modify the imported data
            landmarks = landmarks.copy()

        if self._origin.startswith("bottom"):
            landmarks[:, 1] = frame_dims[0] - landmarks[:, 1]
        if self._origin.endswith("right"):
            landmarks[:, 0] = frame_dims[1] - landmarks[:, 0]

        return landmarks

    def predict(self, feed: np.ndarray) -> np.ndarray:
        """ Pair the input filenames to the import file

        Parameters
        ----------
        feed: :class:`numpy.ndarray`
            The filenames in the batch to return imported alignments for

        Returns
        -------
        :class:`numpy.ndarray`
            The predictions for the given filenames
        """
        preds = []
        for key, frame_dims in feed:
            if key not in self._imported:
                self._missing.append(key)
                continue

            remaining, all_lms = self._imported[key]
            # Faces are consumed in import order: one face per occurrence of the key in the feed
            preds.append(self._adjust_for_origin(all_lms[all_lms.shape[0] - remaining],
                                                 frame_dims))

            if remaining == 1:
                # Last face for this frame, so anything left in _imported at completion is
                # import data that was never matched to an input frame
                del self._imported[key]
            else:
                self._imported[key] = (remaining - 1, all_lms)

        return np.array(preds, dtype="float32")

    def process_output(self, batch: BatchType) -> None:
        """ Process the imported data to the landmarks attribute

        Parameters
        ----------
        batch: :class:`AlignerBatch`
            The current batch from the model with :attr:`predictions` populated
        """
        assert isinstance(batch, AlignerBatch)
        batch.landmarks = batch.prediction
        logger.trace("Imported landmarks: %s", batch.landmarks)  # type:ignore[attr-defined]

    def on_completion(self) -> None:
        """ Output information if:
        - Imported items were not matched in input data
        - Input data was not matched in imported items
        """
        super().on_completion()

        if self._missing:
            logger.warning("[ALIGN] %s input frames could not be matched in the import file "
                           "'%s'. Run in verbose mode for a list of frames.",
                           len(self._missing), self.config["file_name"])
            logger.verbose(  # type:ignore[attr-defined]
                "[ALIGN] Input frames not in import file: %s", self._missing)

        if self._imported:
            logger.warning("[ALIGN] %s items in the import file '%s' could not be matched to any "
                           "input frames. Run in verbose mode for a list of items.",
                           len(self._imported), self.config["file_name"])
            logger.verbose(  # type:ignore[attr-defined]
                "[ALIGN] import file items not in input frames: %s", list(self._imported))