Alignments update:

- Store face embeddings in PNG header when sorting
  - typing + refactor
  - Update alignments keys for 'identity' and 'video_meta' + bump to v2.3
  - General typing fixes
This commit is contained in:
torzdf 2022-09-25 18:22:48 +01:00
parent 376c419498
commit e5356a417e
6 changed files with 959 additions and 630 deletions

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,7 @@ import sys
import os
from hashlib import sha1
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import cast, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from zlib import compress, decompress
import cv2
@ -104,6 +104,7 @@ class DetectedFace():
self.top = top
self.height = height
self._landmarks_xy = landmarks_xy
self._identity: Dict[Literal["vggface2"], np.ndarray] = {}
self.thumbnail: Optional[np.ndarray] = None
self.mask = {} if mask is None else mask
self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None
@ -135,6 +136,11 @@ class DetectedFace():
assert self.top is not None and self.height is not None
return self.top + self.height
@property
def identity(self) -> Dict[Literal["vggface2"], np.ndarray]:
""" dict: Identity mechanism as key, identity embedding as value. """
return self._identity
def add_mask(self,
name: str,
mask: np.ndarray,
@ -173,6 +179,23 @@ class DetectedFace():
fsmask.add(mask, affine_matrix, interpolator)
self.mask[name] = fsmask
def add_identity(self, name: Literal["vggface2"], embedding: np.ndarray, ) -> None:
""" Add an identity embedding to this detected face. If an identity already exists for the
given :attr:`name` it will be overwritten
Parameters
----------
name: str
The name of the mechanism that calculated the identity
embedding: numpy.ndarray
The identity embedding
"""
logger.trace("name: '%s', embedding shape: %s", # type: ignore
name, embedding.shape)
assert name == "vggface2"
assert embedding.shape[0] == 512
self._identity[name] = embedding
def get_landmark_mask(self,
area: Literal["eye", "face", "mouth"],
blur_kernel: int,
@ -271,6 +294,7 @@ class DetectedFace():
landmarks_xy=self.landmarks_xy,
mask={name: mask.to_dict()
for name, mask in self.mask.items()},
identity={k: v.tolist() for k, v in self._identity.items()},
thumb=self.thumbnail)
logger.trace("Returning: %s", alignment) # type: ignore
return alignment
@ -306,6 +330,8 @@ class DetectedFace():
landmarks = alignment["landmarks_xy"]
if not isinstance(landmarks, np.ndarray):
landmarks = np.array(landmarks, dtype="float32")
self._identity = {cast(Literal["vggface2"], k): np.array(v, dtype="float32")
for k, v in alignment.get("identity", {}).items()}
self._landmarks_xy = landmarks.copy()
if with_thumb:
@ -340,7 +366,8 @@ class DetectedFace():
y=self.top,
h=self.height,
landmarks_xy=self.landmarks_xy.tolist(),
mask={name: mask.to_png_meta() for name, mask in self.mask.items()})
mask={name: mask.to_png_meta() for name, mask in self.mask.items()},
identity={k: v.tolist() for k, v in self._identity.items()})
return alignment
def from_png_meta(self, alignment: PNGHeaderAlignmentsDict) -> None:
@ -361,9 +388,14 @@ class DetectedFace():
for name, mask_dict in alignment["mask"].items():
self.mask[name] = Mask()
self.mask[name].from_dict(mask_dict)
self._identity = {}
for key, val in alignment.get("identity", {}).items():
assert key in ["vggface2"]
self._identity[cast(Literal["vggface2"], key)] = np.array(val, dtype="float32")
logger.trace("Created from png exif header: (left: %s, width: %s, top: %s " # type: ignore
" height: %s, andmarks: %s, mask: %s)", self.left, self.width, self.top,
self.height, self.landmarks_xy, self.mask)
" height: %s, landmarks: %s, mask: %s, identity: %s)", self.left, self.width,
self.top, self.height, self.landmarks_xy, self.mask,
{k: v.shape for k, v in self._identity.items()})
def _image_to_face(self, image: np.ndarray) -> None:
""" set self.image to be the cropped face from detected bounding box """

View File

@ -6,7 +6,7 @@ import os
import tkinter as tk
from tkinter import ttk
from typing import Union, List, Tuple
from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING
from math import ceil, floor
import numpy as np
@ -20,6 +20,9 @@ from matplotlib.backend_bases import NavigationToolbar2
from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask
if TYPE_CHECKING:
from matplotlib.lines import Line2D
matplotlib.use("TkAgg")
logger: logging.Logger = logging.getLogger(__name__)
@ -46,8 +49,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._ylabel = ylabel
self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
"summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
self._lines = []
self._toolbar = None
self._lines: List["Line2D"] = []
self._toolbar: Optional["NavigationToolbar"] = None
self._fig = Figure(figsize=(4, 4), dpi=75)
self._ax1 = self._fig.add_subplot(1, 1, 1)
@ -84,7 +87,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
Whether the graph should be initialized for the first time (``True``) or data is being
updated for an existing graph (``False``). Default: ``True``
"""
logger.trace("Updating plot")
logger.trace("Updating plot") # type:ignore
if initiate:
logger.debug("Initializing plot")
self._lines = []
@ -112,7 +115,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
if initiate:
self._legend_place()
logger.trace("Updated plot")
logger.trace("Updated plot") # type:ignore
def _axes_labels_set(self) -> None:
""" Set the X and Y axes labels. """
@ -145,12 +148,13 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
ymin, ymax = self._axes_data_get_min_max(data)
self._ax1.set_ylim(ymin, ymax)
self._ax1.set_xlim(xmin, xmax)
logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", ymin, ymax, xmax)
logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", # type:ignore
ymin, ymax, xmax)
else:
self._axes_limits_set_default()
@staticmethod
def _axes_data_get_min_max(data: List[float]) -> Tuple[float]:
def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
""" Obtain the minimum and maximum values for the y-axis from the given data points.
Parameters
@ -163,14 +167,14 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
tuple
The minimum and maximum values for the y axis
"""
ymin, ymax = [], []
ymins, ymaxs = [], []
for item in data: # TODO Handle as array not loop
ymin.append(np.nanmin(item) * 1000)
ymax.append(np.nanmax(item) * 1000)
ymin = floor(min(ymin)) / 1000
ymax = ceil(max(ymax)) / 1000
logger.trace("ymin: %s, ymax: %s", ymin, ymax)
ymins.append(np.nanmin(item) * 1000)
ymaxs.append(np.nanmax(item) * 1000)
ymin = floor(min(ymins)) / 1000
ymax = ceil(max(ymaxs)) / 1000
logger.trace("ymin: %s, ymax: %s", ymin, ymax) # type:ignore
return ymin, ymax
def _axes_set_yscale(self, scale: str) -> None:
@ -197,9 +201,9 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
list
A list of loss keys with their corresponding line formatting and color information
"""
logger.trace("Sorting lines")
raw_lines = []
sorted_lines = []
logger.trace("Sorting lines") # type:ignore
raw_lines: List[List[str]] = []
sorted_lines: List[List[str]] = []
for key in sorted(keys):
title = key.replace("_", " ").title()
if key.startswith("raw"):
@ -213,7 +217,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
return lines
@staticmethod
def _lines_groupsize(raw_lines: List[str], sorted_lines: List[str]) -> int:
def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
""" Get the number of items in each group.
If raw data isn't selected, then check the length of remaining groups until something is
@ -238,11 +242,11 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
keys = [key[0][:key[0].find("_")] for key in sorted_lines]
distinct_keys = set(keys)
groupsize = len(keys) // len(distinct_keys)
logger.trace(groupsize)
logger.trace(groupsize) # type:ignore
return groupsize
def _lines_style(self,
lines: List[str],
lines: List[List[str]],
groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]:
""" Obtain the color map and line width for each group.
@ -258,14 +262,15 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
list
A list of loss keys with their corresponding line formatting and color information
"""
logger.trace("Setting lines style")
logger.trace("Setting lines style") # type:ignore
groups = int(len(lines) / groupsize)
colours = self._lines_create_colors(groupsize, groups)
widths = list(range(1, groups + 1))
for idx, item in enumerate(lines):
retval = cast(List[List[Union[str, int, Tuple[float]]]], lines)
for idx, item in enumerate(retval):
linewidth = widths[idx // groupsize]
item.extend((linewidth, colours[idx]))
return lines
return retval
def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]:
""" Create the color maps.
@ -288,7 +293,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
cmap = matplotlib.cm.get_cmap(colour)
cpoint = 1 - (i / 5)
colours.append(cmap(cpoint))
logger.trace(colours)
logger.trace(colours) # type:ignore
return colours
def _legend_place(self) -> None:
@ -331,13 +336,13 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
super().__init__(parent, data, ylabel)
self._thread = None # Thread for LongRunningTask
self._displayed_keys = []
self._thread: Optional[LongRunningTask] = None # Thread for LongRunningTask
self._displayed_keys: List[str] = []
self._add_callback()
def _add_callback(self) -> None:
""" Add the variable trace to update graph on refresh button press or save iteration. """
get_config().tk_vars["refreshgraph"].trace("w", self.refresh)
get_config().tk_vars["refreshgraph"].trace("w", self.refresh) # type:ignore
def build(self) -> None:
""" Build the Training graph. """
@ -347,7 +352,7 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def refresh(self, *args) -> None: # pylint: disable=unused-argument
""" Read the latest loss data and apply to current graph """
refresh_var = get_config().tk_vars["refreshgraph"]
refresh_var = cast(tk.BooleanVar, get_config().tk_vars["refreshgraph"])
if not refresh_var.get() and self._thread is None:
return
@ -402,8 +407,8 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
class Event(): # pylint: disable=too-few-public-methods
""" Event class that needs to be passed to plotcanvas.resize """
pass # pylint: disable=unnecessary-pass
Event.width = self.winfo_width()
Event.height = self.winfo_height()
setattr(Event, "width", self.winfo_width())
setattr(Event, "height", self.winfo_height())
self._plotcanvas.resize(Event) # pylint: disable=no-value-for-parameter
@ -485,7 +490,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
def __init__(self, # pylint: disable=super-init-not-called
canvas: FigureCanvasTkAgg,
window: SessionGraph,
window: ttk.Frame,
*,
pack_toolbar: bool = True) -> None:
@ -558,7 +563,10 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
img = get_images().icons[icon]
if not toggle:
btn = ttk.Button(frame, text=text, image=img, command=command)
btn: Union[ttk.Button, ttk.Checkbutton] = ttk.Button(frame,
text=text,
image=img,
command=command)
else:
var = tk.IntVar(master=frame)
btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var)

View File

@ -10,6 +10,8 @@ exclude = .git, __pycache__
ignore_missing_imports = True
[mypy-fastcluster.*]
ignore_missing_imports = True
[mypy-ffmpy.*]
ignore_missing_imports = True
[mypy-imageio.*]
ignore_missing_imports = True
[mypy-imageio_ffmpeg.*]
@ -32,6 +34,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-pynvx.*]
ignore_missing_imports = True
[mypy-pytest.*]
ignore_missing_imports = True
[mypy-scipy.*]
ignore_missing_imports = True
[mypy-sklearn.*]

View File

@ -15,7 +15,7 @@ import numpy as np
from tqdm import tqdm
from lib.align import AlignedFace, DetectedFace
from lib.image import FacesLoader, ImagesLoader, read_image_meta_batch
from lib.image import FacesLoader, ImagesLoader, read_image_meta_batch, update_existing_metadata
from lib.utils import FaceswapError
from plugins.extract.recognition.vgg_face2_keras import Cluster, VGGFace2 as VGGFace
@ -26,7 +26,7 @@ else:
if TYPE_CHECKING:
from argparse import Namespace
from lib.align.alignments import PNGHeaderAlignmentsDict
from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderSourceDict
logger = logging.getLogger(__name__)
@ -57,6 +57,7 @@ class InfoLoader():
self._iterator = None
self._description = "Reading image statistics..."
self._loader = ImagesLoader(input_dir) if info_type == "face" else FacesLoader(input_dir)
self._cached_source_data: Dict[str, "PNGHeaderSourceDict"] = {}
if self._loader.count == 0:
logger.error("No images to process in location: '%s'", input_dir)
sys.exit(1)
@ -100,12 +101,18 @@ class InfoLoader():
iterator = self._get_iterator()
return iterator
@classmethod
def _get_alignments(cls, metadata: Dict[str, Any]) -> Optional["PNGHeaderAlignmentsDict"]:
""" Obtain the alignments from a PNG Header
def _get_alignments(self,
filename: str,
metadata: Dict[str, Any]) -> Optional["PNGHeaderAlignmentsDict"]:
""" Obtain the alignments from a PNG Header.
The other image metadata is cached locally in case a sort method needs to write back to the
PNG header
Parameters
----------
filename: str
Full path to the image PNG file
metadata: dict
The header data from a PNG file
@ -114,8 +121,9 @@ class InfoLoader():
dict or ``None``
The alignments dictionary from the PNG header, if it exists, otherwise ``None``
"""
if not metadata or not metadata.get("alignments"):
if not metadata or not metadata.get("alignments") or not metadata.get("source"):
return None
self._cached_source_data[filename] = metadata["source"]
return metadata["alignments"]
def _metadata_reader(self) -> ImgMetaType:
@ -134,7 +142,7 @@ class InfoLoader():
total=self._loader.count,
desc=self._description,
leave=False):
alignments = self._get_alignments(metadata.get("itxt", {}))
alignments = self._get_alignments(filename, metadata.get("itxt", {}))
yield filename, None, alignments
def _full_data_reader(self) -> ImgMetaType:
@ -153,7 +161,7 @@ class InfoLoader():
desc=self._description,
total=self._loader.count,
leave=False):
alignments = self._get_alignments(metadata)
alignments = self._get_alignments(filename, metadata)
yield filename, image, alignments
def _image_data_reader(self) -> ImgMetaType:
@ -174,6 +182,28 @@ class InfoLoader():
leave=False):
yield filename, image, None
def update_png_header(self, filename: str, alignments: "PNGHeaderAlignmentsDict") -> None:
""" Update the PNG header of the given file with the given alignments.
NB: Header information can only be updated if the face is already on at least alignment
version 2.2. If below this version, then the header is not updated
Parameters
----------
filename: str
Full path to the PNG file to update
alignments: dict
The alignments to update into the PNG header
"""
vers = self._cached_source_data[filename]["alignments_version"]
if vers < 2.2:
return
self._cached_source_data[filename]["alignments_version"] = 2.3 if vers == 2.2 else vers
header = dict(alignments=alignments, source=self._cached_source_data[filename])
update_existing_metadata(filename, header)
class SortMethod():
""" Parent class for sort methods. All sort methods should inherit from this class
@ -805,13 +835,17 @@ class SortFace(SortMethod):
self._vgg_face = VGGFace(exclude_gpus=arguments.exclude_gpus)
self._vgg_face.init_model()
threshold = arguments.threshold
self._output_update_info = True
self._threshold: Optional[float] = 0.25 if threshold < 0 else threshold
def score_image(self,
filename: str,
image: Optional[np.ndarray],
alignments: Optional["PNGHeaderAlignmentsDict"]) -> None:
""" Processing logic for sort by face method
""" Processing logic for sort by face method.
Reads header information from the PNG file to look for VGGFace2 embedding. If it does not
exist, the embedding is obtained and added back into the PNG Header.
Parameters
----------
@ -822,23 +856,37 @@ class SortFace(SortMethod):
alignments: dict or ``None``
The alignments dictionary for the aligned face or ``None``
"""
if self._log_once:
msg = "Grouping" if self._is_group else "Sorting"
logger.info("%s by identity similarity...", msg)
self._log_once = False
if not alignments:
msg = ("The images to be sorted do not contain alignment data. Images must have "
"been generated by Faceswap's Extract process.\nIf you are sorting an "
"older faceset, then you should re-extract the faces from your source "
"alignments file to generate this data.")
raise FaceswapError(msg)
if self._log_once:
msg = "Grouping" if self._is_group else "Sorting"
logger.info("%s by identity similarity...", msg)
self._log_once = False
if alignments.get("identity", {}).get("vggface2"):
embedding = np.array(alignments["identity"]["vggface2"], dtype="float32")
self._result.append((filename, embedding))
return
if self._output_update_info:
logger.info("VGG Face2 Embeddings are being written to the image header. "
"Sorting by this method will be quicker next time")
self._output_update_info = False
face = AlignedFace(np.array(alignments["landmarks_xy"], dtype="float32"),
image=image,
centering="legacy",
size=self._vgg_face.input_size,
is_aligned=True).face
self._result.append((filename, self._vgg_face.predict(face)))
embedding = self._vgg_face.predict(face)
alignments.setdefault("identity", {})["vggface2"] = embedding.tolist()
self._iterator.update_png_header(filename, alignments)
self._result.append((filename, embedding))
def sort(self) -> None:
""" Sort by dendogram.

View File

@ -6,7 +6,7 @@ import logging
import operator
import sys
from typing import Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Dict, List, Optional, TYPE_CHECKING, Union
import numpy as np
from tqdm import tqdm
@ -22,11 +22,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ImgMetaType = Generator[Tuple[str,
Optional[np.ndarray],
Optional["PNGHeaderAlignmentsDict"]], None, None]
class SortAlignedMetric(SortMethod): # pylint:disable=too-few-public-methods
""" Sort by comparison of metrics stored in an Aligned Face objects. This is a parent class
for sort by aligned metrics methods. Individual methods should inherit from this class