mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 12:20:27 +01:00
Alignments Tool updates
- Copy info back to alignments file from faces
This commit is contained in:
parent
5805d76de4
commit
c79175cbde
|
|
@ -190,7 +190,7 @@ class Alignments():
|
|||
pts_time: List[float] = []
|
||||
keyframes: List[int] = []
|
||||
for idx, key in enumerate(sorted(self.data)):
|
||||
if not self.data[key]["video_meta"]:
|
||||
if not self.data[key].get("video_meta", {}):
|
||||
return retval
|
||||
meta = self.data[key]["video_meta"]
|
||||
pts_time.append(cast(float, meta["pts_time"]))
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from sklearn import decomposition
|
|||
from tqdm import tqdm
|
||||
|
||||
from .media import Faces, Frames
|
||||
from .jobs_faces import FaceToFile
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
|
|
@ -21,7 +22,7 @@ else:
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
from lib.align.alignments import PNGHeaderSourceDict
|
||||
from lib.align.alignments import PNGHeaderDict
|
||||
from .media import AlignmentData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -80,7 +81,7 @@ class Check():
|
|||
logger.debug("type: '%s', source_dir: '%s'", self._type, source_dir)
|
||||
return source_dir
|
||||
|
||||
def _get_items(self) -> Union[List[Dict[str, str]], List[Dict[str, "PNGHeaderSourceDict"]]]:
|
||||
def _get_items(self) -> Union[List[Dict[str, str]], List[Tuple[str, "PNGHeaderDict"]]]:
|
||||
""" Set the correct items to process
|
||||
|
||||
Returns
|
||||
|
|
@ -93,7 +94,7 @@ class Check():
|
|||
assert self._type is not None
|
||||
items: Union[Frames, Faces] = globals()[self._type.title()](self._source_dir)
|
||||
self._is_video = items.is_video
|
||||
return cast(Union[List[Dict[str, str]], List[Dict[str, "PNGHeaderSourceDict"]]],
|
||||
return cast(Union[List[Dict[str, str]], List[Tuple[str, "PNGHeaderDict"]]],
|
||||
items.file_list_sorted)
|
||||
|
||||
def process(self) -> None:
|
||||
|
|
@ -101,6 +102,13 @@ class Check():
|
|||
assert self._type is not None
|
||||
logger.info("[CHECK %s]", self._type.upper())
|
||||
items_output = self._compile_output()
|
||||
|
||||
if self._type == "faces":
|
||||
filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._items)
|
||||
check_update = FaceToFile(self._alignments, [val[1] for val in filelist])
|
||||
if check_update():
|
||||
self._alignments.save()
|
||||
|
||||
self._output_results(items_output)
|
||||
|
||||
def _validate(self) -> None:
|
||||
|
|
@ -185,12 +193,13 @@ class Check():
|
|||
The frame name and the face id of any frames which have multiple faces
|
||||
"""
|
||||
self.output_message = "Multiple faces in frame"
|
||||
for item in tqdm(cast(List[Tuple[str, "PNGHeaderSourceDict"]], self._items),
|
||||
for item in tqdm(cast(List[Tuple[str, "PNGHeaderDict"]], self._items),
|
||||
desc=self.output_message,
|
||||
leave=False):
|
||||
if not self._alignments.frame_has_multiple_faces(item["source_filename"]):
|
||||
src = item[1]["source"]
|
||||
if not self._alignments.frame_has_multiple_faces(src["source_filename"]):
|
||||
continue
|
||||
retval = (item[0], item[1]["face_index"])
|
||||
retval = (item[0], src["face_index"])
|
||||
logger.trace("Returning: '%s'", retval) # type:ignore
|
||||
yield retval
|
||||
|
||||
|
|
@ -222,7 +231,7 @@ class Check():
|
|||
The frame name of any frames in alignments with no matching file
|
||||
"""
|
||||
self.output_message = "Missing frames that are in alignments file"
|
||||
frames = set(item["frame_fullname"] for item in self._items)
|
||||
frames = set(item["frame_fullname"] for item in cast(List[Dict[str, str]], self._items))
|
||||
for frame in tqdm(self._alignments.data.keys(), desc=self.output_message, leave=False):
|
||||
if frame not in frames:
|
||||
logger.debug("Returning: '%s'", frame)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Tools for manipulating the alignments using extracted Faces as a source """
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from operator import itemgetter
|
||||
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
|
@ -15,10 +16,15 @@ from scripts.fsmedia import Alignments
|
|||
|
||||
from .media import Faces
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .media import AlignmentData
|
||||
from lib.align.alignments import (AlignmentDict, AlignmentFileDict,
|
||||
PNGHeaderDict, PNGHeaderSourceDict)
|
||||
PNGHeaderDict, PNGHeaderAlignmentsDict)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,7 +43,7 @@ class FromFaces(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Initializing %s: (alignments: %s, arguments: %s)",
|
||||
self.__class__.__name__, alignments, arguments)
|
||||
self._faces_dir = arguments.faces_dir
|
||||
self._faces = Faces(arguments.faces_dir, with_alignments=True)
|
||||
self._faces = Faces(arguments.faces_dir)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def process(self) -> None:
|
||||
|
|
@ -240,7 +246,7 @@ class Rename(): # pylint:disable=too-few-public-methods
|
|||
self.__class__.__name__, arguments, faces)
|
||||
self._alignments = alignments
|
||||
|
||||
kwargs: Dict[str, Union[bool, "AlignmentData"]] = dict(with_alignments=False)
|
||||
kwargs = {}
|
||||
if alignments.version < 2.1:
|
||||
# Update headers of faces generated with hash based alignments
|
||||
kwargs["alignments"] = alignments
|
||||
|
|
@ -254,14 +260,19 @@ class Rename(): # pylint:disable=too-few-public-methods
|
|||
def process(self) -> None:
|
||||
""" Process the face renaming """
|
||||
logger.info("[RENAME FACES]") # Tidy up cli output
|
||||
filelist = cast(List[Tuple[str, "PNGHeaderSourceDict"]], self._faces.file_list_sorted)
|
||||
rename_mappings = sorted([(face[0], face[1]["original_filename"])
|
||||
filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted)
|
||||
rename_mappings = sorted([(face[0], face[1]["source"]["original_filename"])
|
||||
for face in filelist
|
||||
if face[0] != face[1]["original_filename"]],
|
||||
if face[0] != face[1]["source"]["original_filename"]],
|
||||
key=lambda x: x[1])
|
||||
rename_count = self._rename_faces(rename_mappings)
|
||||
logger.info("%s faces renamed", rename_count)
|
||||
|
||||
filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted)
|
||||
copyback = FaceToFile(self._alignments, [val[1] for val in filelist])
|
||||
if copyback():
|
||||
self._alignments.save()
|
||||
|
||||
def _rename_faces(self, filename_mappings: List[Tuple[str, str]]) -> int:
|
||||
""" Rename faces back to their original name as exists in the alignments file.
|
||||
|
||||
|
|
@ -325,7 +336,7 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
|
||||
self._alignments = alignments
|
||||
|
||||
kwargs: Dict[str, Union[bool, "AlignmentData"]] = dict(with_alignments=False)
|
||||
kwargs = {}
|
||||
if alignments.version < 2.1:
|
||||
# Update headers of faces generated with hash based alignments
|
||||
kwargs["alignments"] = alignments
|
||||
|
|
@ -367,10 +378,11 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods
|
|||
to like this and has a tendency to throw permission errors, so this remains single threaded
|
||||
for now.
|
||||
"""
|
||||
filelist = cast(List[Tuple[str, "PNGHeaderSourceDict"]], self._items.file_list_sorted)
|
||||
items = cast(Dict[str, List[int]], self._items.items)
|
||||
srcs = [(x[0], x[1]["source"])
|
||||
for x in cast(List[Tuple[str, "PNGHeaderDict"]], self._items.file_list_sorted)]
|
||||
to_update = [ # Items whose face index has changed
|
||||
x for x in filelist
|
||||
x for x in srcs
|
||||
if x[1]["face_index"] != items[x[1]["source_filename"]].index(x[1]["face_index"])]
|
||||
|
||||
for item in tqdm(to_update, desc="Updating PNG Headers", leave=False):
|
||||
|
|
@ -400,3 +412,80 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods
|
|||
update_existing_metadata(fullpath, meta)
|
||||
|
||||
logger.info("%s Extracted face(s) had their header information updated", len(to_update))
|
||||
|
||||
|
||||
class FaceToFile(): # pylint:disable=too-few-public-methods
|
||||
""" Updates any optional/missing keys in the alignments file with any data that has been
|
||||
populated in a PNGHeader. Includes masks and identity fields.
|
||||
|
||||
Parameters
|
||||
---------
|
||||
alignments: :class:`tools.alignments.media.AlignmentsData`
|
||||
The loaded alignments containing faces to be removed
|
||||
face_data: list
|
||||
List of :class:`PNGHeaderDict` objects
|
||||
"""
|
||||
def __init__(self, alignments: "AlignmentData", face_data: List["PNGHeaderDict"]) -> None:
|
||||
logger.debug("Initializing %s: alignments: %s, face_data: %s",
|
||||
self.__class__.__name__, alignments, len(face_data))
|
||||
self._alignments = alignments
|
||||
self._face_alignments = face_data
|
||||
self._updatable_keys: List[Literal["identity", "mask"]] = ["identity", "mask"]
|
||||
self._counts: Dict[str, int] = {}
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def _check_and_update(self,
|
||||
alignment: "PNGHeaderAlignmentsDict",
|
||||
face: "AlignmentFileDict") -> None:
|
||||
""" Check whether the key requires updating and update it.
|
||||
|
||||
alignment: dict
|
||||
The alignment dictionary from the PNG Header
|
||||
face: dict
|
||||
The alignment dictionary for the face from the alignments file
|
||||
"""
|
||||
for key in self._updatable_keys:
|
||||
if key == "mask":
|
||||
exist_masks = face["mask"]
|
||||
for mask_name, mask_data in alignment["mask"].items():
|
||||
if mask_name in exist_masks:
|
||||
continue
|
||||
exist_masks[mask_name] = mask_data
|
||||
count_key = f"mask_{mask_name}"
|
||||
self._counts[count_key] = self._counts.get(count_key, 0) + 1
|
||||
continue
|
||||
|
||||
if not face.get(key, {}) and alignment.get(key):
|
||||
face[key] = alignment[key]
|
||||
self._counts[key] = self._counts.get(key, 0) + 1
|
||||
|
||||
def __call__(self) -> bool:
|
||||
""" Parse through the face data updating any entries in the alignments file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
``True`` if any alignment information was updated otherwise ``False``
|
||||
"""
|
||||
for meta in tqdm(self._face_alignments,
|
||||
desc="Updating Alignments File from PNG Header",
|
||||
leave=False):
|
||||
src = meta["source"]
|
||||
alignment = meta["alignments"]
|
||||
if not any(alignment.get(key, {}) for key in self._updatable_keys):
|
||||
continue
|
||||
|
||||
faces = self._alignments.get_faces_in_frame(src["source_filename"])
|
||||
if len(faces) < src["face_index"] + 1: # list index out of range
|
||||
logger.debug("Skipped face '%s'. Index does not exist in alignments file",
|
||||
src["original_filename"])
|
||||
continue
|
||||
|
||||
face = faces[src["face_index"]]
|
||||
self._check_and_update(alignment, face)
|
||||
|
||||
retval = False
|
||||
if self._counts:
|
||||
retval = True
|
||||
logger.info("Updated alignments file from PNG Data: %s", self._counts)
|
||||
return retval
|
||||
|
|
|
|||
|
|
@ -237,9 +237,10 @@ class Extract(): # pylint:disable=too-few-public-methods
|
|||
meta = self._alignments.video_meta_data
|
||||
has_meta = all(val is not None for val in meta.values())
|
||||
if has_meta:
|
||||
retval = None
|
||||
retval: Optional[int] = len(cast(Dict[str, Union[List[int], List[float]]],
|
||||
meta["pts_time"]))
|
||||
else:
|
||||
retval = len(cast(Dict[str, Union[List[int], List[float]]], meta["pts_time"]))
|
||||
retval = None
|
||||
logger.debug("Frame count from alignments file: (has_meta: %s, %s", has_meta, retval)
|
||||
return retval
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from lib.utils import _image_extensions, _video_extensions, FaceswapError
|
|||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from lib.align.alignments import AlignmentFileDict, PNGHeaderDict, PNGHeaderSourceDict
|
||||
from lib.align.alignments import AlignmentFileDict, PNGHeaderDict
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
|
@ -151,14 +151,12 @@ class MediaLoader():
|
|||
return retval
|
||||
|
||||
def sorted_items(self) -> Union[List[Dict[str, str]],
|
||||
List[Tuple[str, "PNGHeaderSourceDict"]],
|
||||
List[Tuple[str, "PNGHeaderDict"]]]:
|
||||
""" Override for specific folder processing """
|
||||
raise NotImplementedError()
|
||||
|
||||
def process_folder(self) -> Union[Generator[Dict[str, str], None, None],
|
||||
Generator[Tuple[str, "PNGHeaderDict"], None, None],
|
||||
Generator[Tuple[str, "PNGHeaderSourceDict"], None, None]]:
|
||||
Generator[Tuple[str, "PNGHeaderDict"], None, None]]:
|
||||
""" Override for specific folder processing """
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -265,21 +263,12 @@ class Faces(MediaLoader):
|
|||
The alignments object that contains the faces. Used to update legacy hash based faces
|
||||
for <v2.1 alignments to png header based version. Pass in ``None`` to not update legacy
|
||||
faces (raises error instead). Default: ``None``
|
||||
with_alignments: bool, optional
|
||||
By default, only the source information stored in the PNG header will be returned in
|
||||
:attr:`file_list_sorted`. Set to ``True`` to include alignment information as well.
|
||||
Default:``False``
|
||||
"""
|
||||
def __init__(self,
|
||||
folder: str,
|
||||
alignments: Optional[Alignments] = None,
|
||||
with_alignments: bool = False) -> None:
|
||||
def __init__(self, folder: str, alignments: Optional[Alignments] = None) -> None:
|
||||
self._alignments = alignments
|
||||
self._with_alignments = with_alignments
|
||||
super().__init__(folder)
|
||||
|
||||
def process_folder(self) -> Union[Generator[Tuple[str, "PNGHeaderDict"], None, None],
|
||||
Generator[Tuple[str, "PNGHeaderSourceDict"], None, None]]:
|
||||
def process_folder(self) -> Generator[Tuple[str, "PNGHeaderDict"], None, None]:
|
||||
""" Iterate through the faces folder pulling out various information for each face.
|
||||
|
||||
Yields
|
||||
|
|
@ -321,13 +310,11 @@ class Faces(MediaLoader):
|
|||
f"Some of the faces being passed in from '{self.folder}' could not be "
|
||||
f"matched to the alignments file '{self._alignments.file}'\nPlease double "
|
||||
"check your sources and try again.")
|
||||
sub_dict = data if self._with_alignments else data["source"]
|
||||
sub_dict = data
|
||||
else:
|
||||
sub_dict = (metadata["itxt"] if self._with_alignments
|
||||
else metadata["itxt"]["source"])
|
||||
sub_dict = cast("PNGHeaderDict", metadata["itxt"])
|
||||
|
||||
retval: Union[Tuple[str, "PNGHeaderDict"], Tuple[str, "PNGHeaderSourceDict"]]
|
||||
retval = (os.path.basename(fullpath), sub_dict) # type:ignore
|
||||
retval = (os.path.basename(fullpath), sub_dict)
|
||||
yield retval
|
||||
|
||||
def load_items(self) -> Dict[str, List[int]]:
|
||||
|
|
@ -339,19 +326,13 @@ class Faces(MediaLoader):
|
|||
The source filename as key with list of face indices for the frame as value
|
||||
"""
|
||||
faces: Dict[str, List[int]] = {}
|
||||
for face in cast(Union[List[Tuple[str, "PNGHeaderDict"]],
|
||||
List[Tuple[str, "PNGHeaderSourceDict"]]],
|
||||
self.file_list_sorted):
|
||||
src: "PNGHeaderSourceDict" = cast(
|
||||
"PNGHeaderDict",
|
||||
face[1])["source"] if self._with_alignments else cast("PNGHeaderSourceDict",
|
||||
face[1])
|
||||
for face in cast(List[Tuple[str, "PNGHeaderDict"]], self.file_list_sorted):
|
||||
src = face[1]["source"]
|
||||
faces.setdefault(src["source_filename"], []).append(src["face_index"])
|
||||
logger.trace(faces) # type: ignore
|
||||
return faces
|
||||
|
||||
def sorted_items(self) -> Union[List[Tuple[str, "PNGHeaderDict"]],
|
||||
List[Tuple[str, "PNGHeaderSourceDict"]]]:
|
||||
def sorted_items(self) -> List[Tuple[str, "PNGHeaderDict"]]:
|
||||
""" Return the items sorted by the saved file name.
|
||||
|
||||
Returns
|
||||
|
|
@ -359,9 +340,7 @@ class Faces(MediaLoader):
|
|||
list
|
||||
List of `dict` objects for each face found, sorted by the face's current filename
|
||||
"""
|
||||
items = cast(Union[List[Tuple[str, "PNGHeaderDict"]],
|
||||
List[Tuple[str, "PNGHeaderSourceDict"]]],
|
||||
sorted(self.process_folder(), key=itemgetter(0)))
|
||||
items = sorted(self.process_folder(), key=itemgetter(0))
|
||||
logger.trace(items) # type: ignore
|
||||
return items
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user