mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
Alignments update:
- Store face embeddings in PNG header when sorting - typing + refactor - Update alignments keys for 'identity' and 'video_meta' + bump to v2.3 - General typing fixes
This commit is contained in:
parent
376c419498
commit
e5356a417e
File diff suppressed because it is too large
Load Diff
|
|
@ -6,7 +6,7 @@ import sys
|
|||
import os
|
||||
|
||||
from hashlib import sha1
|
||||
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import cast, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from zlib import compress, decompress
|
||||
|
||||
import cv2
|
||||
|
|
@ -104,6 +104,7 @@ class DetectedFace():
|
|||
self.top = top
|
||||
self.height = height
|
||||
self._landmarks_xy = landmarks_xy
|
||||
self._identity: Dict[Literal["vggface2"], np.ndarray] = {}
|
||||
self.thumbnail: Optional[np.ndarray] = None
|
||||
self.mask = {} if mask is None else mask
|
||||
self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None
|
||||
|
|
@ -135,6 +136,11 @@ class DetectedFace():
|
|||
assert self.top is not None and self.height is not None
|
||||
return self.top + self.height
|
||||
|
||||
@property
|
||||
def identity(self) -> Dict[Literal["vggface2"], np.ndarray]:
|
||||
""" dict: Identity mechanism as key, identity embedding as value. """
|
||||
return self._identity
|
||||
|
||||
def add_mask(self,
|
||||
name: str,
|
||||
mask: np.ndarray,
|
||||
|
|
@ -173,6 +179,23 @@ class DetectedFace():
|
|||
fsmask.add(mask, affine_matrix, interpolator)
|
||||
self.mask[name] = fsmask
|
||||
|
||||
def add_identity(self, name: Literal["vggface2"], embedding: np.ndarray, ) -> None:
|
||||
""" Add an identity embedding to this detected face. If an identity already exists for the
|
||||
given :attr:`name` it will be overwritten
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the mechanism that calculated the identity
|
||||
embedding: numpy.ndarray
|
||||
The identity embedding
|
||||
"""
|
||||
logger.trace("name: '%s', embedding shape: %s", # type: ignore
|
||||
name, embedding.shape)
|
||||
assert name == "vggface2"
|
||||
assert embedding.shape[0] == 512
|
||||
self._identity[name] = embedding
|
||||
|
||||
def get_landmark_mask(self,
|
||||
area: Literal["eye", "face", "mouth"],
|
||||
blur_kernel: int,
|
||||
|
|
@ -271,6 +294,7 @@ class DetectedFace():
|
|||
landmarks_xy=self.landmarks_xy,
|
||||
mask={name: mask.to_dict()
|
||||
for name, mask in self.mask.items()},
|
||||
identity={k: v.tolist() for k, v in self._identity.items()},
|
||||
thumb=self.thumbnail)
|
||||
logger.trace("Returning: %s", alignment) # type: ignore
|
||||
return alignment
|
||||
|
|
@ -306,6 +330,8 @@ class DetectedFace():
|
|||
landmarks = alignment["landmarks_xy"]
|
||||
if not isinstance(landmarks, np.ndarray):
|
||||
landmarks = np.array(landmarks, dtype="float32")
|
||||
self._identity = {cast(Literal["vggface2"], k): np.array(v, dtype="float32")
|
||||
for k, v in alignment.get("identity", {}).items()}
|
||||
self._landmarks_xy = landmarks.copy()
|
||||
|
||||
if with_thumb:
|
||||
|
|
@ -340,7 +366,8 @@ class DetectedFace():
|
|||
y=self.top,
|
||||
h=self.height,
|
||||
landmarks_xy=self.landmarks_xy.tolist(),
|
||||
mask={name: mask.to_png_meta() for name, mask in self.mask.items()})
|
||||
mask={name: mask.to_png_meta() for name, mask in self.mask.items()},
|
||||
identity={k: v.tolist() for k, v in self._identity.items()})
|
||||
return alignment
|
||||
|
||||
def from_png_meta(self, alignment: PNGHeaderAlignmentsDict) -> None:
|
||||
|
|
@ -361,9 +388,14 @@ class DetectedFace():
|
|||
for name, mask_dict in alignment["mask"].items():
|
||||
self.mask[name] = Mask()
|
||||
self.mask[name].from_dict(mask_dict)
|
||||
self._identity = {}
|
||||
for key, val in alignment.get("identity", {}).items():
|
||||
assert key in ["vggface2"]
|
||||
self._identity[cast(Literal["vggface2"], key)] = np.array(val, dtype="float32")
|
||||
logger.trace("Created from png exif header: (left: %s, width: %s, top: %s " # type: ignore
|
||||
" height: %s, andmarks: %s, mask: %s)", self.left, self.width, self.top,
|
||||
self.height, self.landmarks_xy, self.mask)
|
||||
" height: %s, landmarks: %s, mask: %s, identity: %s)", self.left, self.width,
|
||||
self.top, self.height, self.landmarks_xy, self.mask,
|
||||
{k: v.shape for k, v in self._identity.items()})
|
||||
|
||||
def _image_to_face(self, image: np.ndarray) -> None:
|
||||
""" set self.image to be the cropped face from detected bounding box """
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import os
|
|||
import tkinter as tk
|
||||
|
||||
from tkinter import ttk
|
||||
from typing import Union, List, Tuple
|
||||
from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING
|
||||
from math import ceil, floor
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -20,6 +20,9 @@ from matplotlib.backend_bases import NavigationToolbar2
|
|||
from .custom_widgets import Tooltip
|
||||
from .utils import get_config, get_images, LongRunningTask
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
matplotlib.use("TkAgg")
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
|
@ -46,8 +49,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
self._ylabel = ylabel
|
||||
self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
|
||||
"summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
|
||||
self._lines = []
|
||||
self._toolbar = None
|
||||
self._lines: List["Line2D"] = []
|
||||
self._toolbar: Optional["NavigationToolbar"] = None
|
||||
self._fig = Figure(figsize=(4, 4), dpi=75)
|
||||
|
||||
self._ax1 = self._fig.add_subplot(1, 1, 1)
|
||||
|
|
@ -84,7 +87,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
Whether the graph should be initialized for the first time (``True``) or data is being
|
||||
updated for an existing graph (``False``). Default: ``True``
|
||||
"""
|
||||
logger.trace("Updating plot")
|
||||
logger.trace("Updating plot") # type:ignore
|
||||
if initiate:
|
||||
logger.debug("Initializing plot")
|
||||
self._lines = []
|
||||
|
|
@ -112,7 +115,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
|
||||
if initiate:
|
||||
self._legend_place()
|
||||
logger.trace("Updated plot")
|
||||
logger.trace("Updated plot") # type:ignore
|
||||
|
||||
def _axes_labels_set(self) -> None:
|
||||
""" Set the X and Y axes labels. """
|
||||
|
|
@ -145,12 +148,13 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
ymin, ymax = self._axes_data_get_min_max(data)
|
||||
self._ax1.set_ylim(ymin, ymax)
|
||||
self._ax1.set_xlim(xmin, xmax)
|
||||
logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", ymin, ymax, xmax)
|
||||
logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", # type:ignore
|
||||
ymin, ymax, xmax)
|
||||
else:
|
||||
self._axes_limits_set_default()
|
||||
|
||||
@staticmethod
|
||||
def _axes_data_get_min_max(data: List[float]) -> Tuple[float]:
|
||||
def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
|
||||
""" Obtain the minimum and maximum values for the y-axis from the given data points.
|
||||
|
||||
Parameters
|
||||
|
|
@ -163,14 +167,14 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
tuple
|
||||
The minimum and maximum values for the y axis
|
||||
"""
|
||||
ymin, ymax = [], []
|
||||
ymins, ymaxs = [], []
|
||||
|
||||
for item in data: # TODO Handle as array not loop
|
||||
ymin.append(np.nanmin(item) * 1000)
|
||||
ymax.append(np.nanmax(item) * 1000)
|
||||
ymin = floor(min(ymin)) / 1000
|
||||
ymax = ceil(max(ymax)) / 1000
|
||||
logger.trace("ymin: %s, ymax: %s", ymin, ymax)
|
||||
ymins.append(np.nanmin(item) * 1000)
|
||||
ymaxs.append(np.nanmax(item) * 1000)
|
||||
ymin = floor(min(ymins)) / 1000
|
||||
ymax = ceil(max(ymaxs)) / 1000
|
||||
logger.trace("ymin: %s, ymax: %s", ymin, ymax) # type:ignore
|
||||
return ymin, ymax
|
||||
|
||||
def _axes_set_yscale(self, scale: str) -> None:
|
||||
|
|
@ -197,9 +201,9 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
list
|
||||
A list of loss keys with their corresponding line formatting and color information
|
||||
"""
|
||||
logger.trace("Sorting lines")
|
||||
raw_lines = []
|
||||
sorted_lines = []
|
||||
logger.trace("Sorting lines") # type:ignore
|
||||
raw_lines: List[List[str]] = []
|
||||
sorted_lines: List[List[str]] = []
|
||||
for key in sorted(keys):
|
||||
title = key.replace("_", " ").title()
|
||||
if key.startswith("raw"):
|
||||
|
|
@ -213,7 +217,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _lines_groupsize(raw_lines: List[str], sorted_lines: List[str]) -> int:
|
||||
def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
|
||||
""" Get the number of items in each group.
|
||||
|
||||
If raw data isn't selected, then check the length of remaining groups until something is
|
||||
|
|
@ -238,11 +242,11 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
keys = [key[0][:key[0].find("_")] for key in sorted_lines]
|
||||
distinct_keys = set(keys)
|
||||
groupsize = len(keys) // len(distinct_keys)
|
||||
logger.trace(groupsize)
|
||||
logger.trace(groupsize) # type:ignore
|
||||
return groupsize
|
||||
|
||||
def _lines_style(self,
|
||||
lines: List[str],
|
||||
lines: List[List[str]],
|
||||
groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]:
|
||||
""" Obtain the color map and line width for each group.
|
||||
|
||||
|
|
@ -258,14 +262,15 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
list
|
||||
A list of loss keys with their corresponding line formatting and color information
|
||||
"""
|
||||
logger.trace("Setting lines style")
|
||||
logger.trace("Setting lines style") # type:ignore
|
||||
groups = int(len(lines) / groupsize)
|
||||
colours = self._lines_create_colors(groupsize, groups)
|
||||
widths = list(range(1, groups + 1))
|
||||
for idx, item in enumerate(lines):
|
||||
retval = cast(List[List[Union[str, int, Tuple[float]]]], lines)
|
||||
for idx, item in enumerate(retval):
|
||||
linewidth = widths[idx // groupsize]
|
||||
item.extend((linewidth, colours[idx]))
|
||||
return lines
|
||||
return retval
|
||||
|
||||
def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]:
|
||||
""" Create the color maps.
|
||||
|
|
@ -288,7 +293,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
cmap = matplotlib.cm.get_cmap(colour)
|
||||
cpoint = 1 - (i / 5)
|
||||
colours.append(cmap(cpoint))
|
||||
logger.trace(colours)
|
||||
logger.trace(colours) # type:ignore
|
||||
return colours
|
||||
|
||||
def _legend_place(self) -> None:
|
||||
|
|
@ -331,13 +336,13 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
|
||||
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
|
||||
super().__init__(parent, data, ylabel)
|
||||
self._thread = None # Thread for LongRunningTask
|
||||
self._displayed_keys = []
|
||||
self._thread: Optional[LongRunningTask] = None # Thread for LongRunningTask
|
||||
self._displayed_keys: List[str] = []
|
||||
self._add_callback()
|
||||
|
||||
def _add_callback(self) -> None:
|
||||
""" Add the variable trace to update graph on refresh button press or save iteration. """
|
||||
get_config().tk_vars["refreshgraph"].trace("w", self.refresh)
|
||||
get_config().tk_vars["refreshgraph"].trace("w", self.refresh) # type:ignore
|
||||
|
||||
def build(self) -> None:
|
||||
""" Build the Training graph. """
|
||||
|
|
@ -347,7 +352,7 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
|
||||
def refresh(self, *args) -> None: # pylint: disable=unused-argument
|
||||
""" Read the latest loss data and apply to current graph """
|
||||
refresh_var = get_config().tk_vars["refreshgraph"]
|
||||
refresh_var = cast(tk.BooleanVar, get_config().tk_vars["refreshgraph"])
|
||||
if not refresh_var.get() and self._thread is None:
|
||||
return
|
||||
|
||||
|
|
@ -402,8 +407,8 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
class Event(): # pylint: disable=too-few-public-methods
|
||||
""" Event class that needs to be passed to plotcanvas.resize """
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
Event.width = self.winfo_width()
|
||||
Event.height = self.winfo_height()
|
||||
setattr(Event, "width", self.winfo_width())
|
||||
setattr(Event, "height", self.winfo_height())
|
||||
self._plotcanvas.resize(Event) # pylint: disable=no-value-for-parameter
|
||||
|
||||
|
||||
|
|
@ -485,7 +490,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
|
|||
|
||||
def __init__(self, # pylint: disable=super-init-not-called
|
||||
canvas: FigureCanvasTkAgg,
|
||||
window: SessionGraph,
|
||||
window: ttk.Frame,
|
||||
*,
|
||||
pack_toolbar: bool = True) -> None:
|
||||
|
||||
|
|
@ -558,7 +563,10 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
|
|||
img = get_images().icons[icon]
|
||||
|
||||
if not toggle:
|
||||
btn = ttk.Button(frame, text=text, image=img, command=command)
|
||||
btn: Union[ttk.Button, ttk.Checkbutton] = ttk.Button(frame,
|
||||
text=text,
|
||||
image=img,
|
||||
command=command)
|
||||
else:
|
||||
var = tk.IntVar(master=frame)
|
||||
btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ exclude = .git, __pycache__
|
|||
ignore_missing_imports = True
|
||||
[mypy-fastcluster.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-ffmpy.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-imageio.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-imageio_ffmpeg.*]
|
||||
|
|
@ -32,6 +34,8 @@ ignore_missing_imports = True
|
|||
ignore_missing_imports = True
|
||||
[mypy-pynvx.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-pytest.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-scipy.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-sklearn.*]
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import numpy as np
|
|||
from tqdm import tqdm
|
||||
|
||||
from lib.align import AlignedFace, DetectedFace
|
||||
from lib.image import FacesLoader, ImagesLoader, read_image_meta_batch
|
||||
from lib.image import FacesLoader, ImagesLoader, read_image_meta_batch, update_existing_metadata
|
||||
from lib.utils import FaceswapError
|
||||
from plugins.extract.recognition.vgg_face2_keras import Cluster, VGGFace2 as VGGFace
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ else:
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
from lib.align.alignments import PNGHeaderAlignmentsDict
|
||||
from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderSourceDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -57,6 +57,7 @@ class InfoLoader():
|
|||
self._iterator = None
|
||||
self._description = "Reading image statistics..."
|
||||
self._loader = ImagesLoader(input_dir) if info_type == "face" else FacesLoader(input_dir)
|
||||
self._cached_source_data: Dict[str, "PNGHeaderSourceDict"] = {}
|
||||
if self._loader.count == 0:
|
||||
logger.error("No images to process in location: '%s'", input_dir)
|
||||
sys.exit(1)
|
||||
|
|
@ -100,12 +101,18 @@ class InfoLoader():
|
|||
iterator = self._get_iterator()
|
||||
return iterator
|
||||
|
||||
@classmethod
|
||||
def _get_alignments(cls, metadata: Dict[str, Any]) -> Optional["PNGHeaderAlignmentsDict"]:
|
||||
""" Obtain the alignments from a PNG Header
|
||||
def _get_alignments(self,
|
||||
filename: str,
|
||||
metadata: Dict[str, Any]) -> Optional["PNGHeaderAlignmentsDict"]:
|
||||
""" Obtain the alignments from a PNG Header.
|
||||
|
||||
The other image metadata is cached locally in case a sort method needs to write back to the
|
||||
PNG header
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename: str
|
||||
Full path to the image PNG file
|
||||
metadata: dict
|
||||
The header data from a PNG file
|
||||
|
||||
|
|
@ -114,8 +121,9 @@ class InfoLoader():
|
|||
dict or ``None``
|
||||
The alignments dictionary from the PNG header, if it exists, otherwise ``None``
|
||||
"""
|
||||
if not metadata or not metadata.get("alignments"):
|
||||
if not metadata or not metadata.get("alignments") or not metadata.get("source"):
|
||||
return None
|
||||
self._cached_source_data[filename] = metadata["source"]
|
||||
return metadata["alignments"]
|
||||
|
||||
def _metadata_reader(self) -> ImgMetaType:
|
||||
|
|
@ -134,7 +142,7 @@ class InfoLoader():
|
|||
total=self._loader.count,
|
||||
desc=self._description,
|
||||
leave=False):
|
||||
alignments = self._get_alignments(metadata.get("itxt", {}))
|
||||
alignments = self._get_alignments(filename, metadata.get("itxt", {}))
|
||||
yield filename, None, alignments
|
||||
|
||||
def _full_data_reader(self) -> ImgMetaType:
|
||||
|
|
@ -153,7 +161,7 @@ class InfoLoader():
|
|||
desc=self._description,
|
||||
total=self._loader.count,
|
||||
leave=False):
|
||||
alignments = self._get_alignments(metadata)
|
||||
alignments = self._get_alignments(filename, metadata)
|
||||
yield filename, image, alignments
|
||||
|
||||
def _image_data_reader(self) -> ImgMetaType:
|
||||
|
|
@ -174,6 +182,28 @@ class InfoLoader():
|
|||
leave=False):
|
||||
yield filename, image, None
|
||||
|
||||
def update_png_header(self, filename: str, alignments: "PNGHeaderAlignmentsDict") -> None:
|
||||
""" Update the PNG header of the given file with the given alignments.
|
||||
|
||||
NB: Header information can only be updated if the face is already on at least alignment
|
||||
version 2.2. If below this version, then the header is not updated
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename: str
|
||||
Full path to the PNG file to update
|
||||
alignments: dict
|
||||
The alignments to update into the PNG header
|
||||
"""
|
||||
vers = self._cached_source_data[filename]["alignments_version"]
|
||||
if vers < 2.2:
|
||||
return
|
||||
|
||||
self._cached_source_data[filename]["alignments_version"] = 2.3 if vers == 2.2 else vers
|
||||
header = dict(alignments=alignments, source=self._cached_source_data[filename])
|
||||
update_existing_metadata(filename, header)
|
||||
|
||||
|
||||
class SortMethod():
|
||||
""" Parent class for sort methods. All sort methods should inherit from this class
|
||||
|
|
@ -805,13 +835,17 @@ class SortFace(SortMethod):
|
|||
self._vgg_face = VGGFace(exclude_gpus=arguments.exclude_gpus)
|
||||
self._vgg_face.init_model()
|
||||
threshold = arguments.threshold
|
||||
self._output_update_info = True
|
||||
self._threshold: Optional[float] = 0.25 if threshold < 0 else threshold
|
||||
|
||||
def score_image(self,
|
||||
filename: str,
|
||||
image: Optional[np.ndarray],
|
||||
alignments: Optional["PNGHeaderAlignmentsDict"]) -> None:
|
||||
""" Processing logic for sort by face method
|
||||
""" Processing logic for sort by face method.
|
||||
|
||||
Reads header information from the PNG file to look for VGGFace2 embedding. If it does not
|
||||
exist, the embedding is obtained and added back into the PNG Header.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -822,23 +856,37 @@ class SortFace(SortMethod):
|
|||
alignments: dict or ``None``
|
||||
The alignments dictionary for the aligned face or ``None``
|
||||
"""
|
||||
if self._log_once:
|
||||
msg = "Grouping" if self._is_group else "Sorting"
|
||||
logger.info("%s by identity similarity...", msg)
|
||||
self._log_once = False
|
||||
|
||||
if not alignments:
|
||||
msg = ("The images to be sorted do not contain alignment data. Images must have "
|
||||
"been generated by Faceswap's Extract process.\nIf you are sorting an "
|
||||
"older faceset, then you should re-extract the faces from your source "
|
||||
"alignments file to generate this data.")
|
||||
raise FaceswapError(msg)
|
||||
|
||||
if self._log_once:
|
||||
msg = "Grouping" if self._is_group else "Sorting"
|
||||
logger.info("%s by identity similarity...", msg)
|
||||
self._log_once = False
|
||||
|
||||
if alignments.get("identity", {}).get("vggface2"):
|
||||
embedding = np.array(alignments["identity"]["vggface2"], dtype="float32")
|
||||
self._result.append((filename, embedding))
|
||||
return
|
||||
|
||||
if self._output_update_info:
|
||||
logger.info("VGG Face2 Embeddings are being written to the image header. "
|
||||
"Sorting by this method will be quicker next time")
|
||||
self._output_update_info = False
|
||||
|
||||
face = AlignedFace(np.array(alignments["landmarks_xy"], dtype="float32"),
|
||||
image=image,
|
||||
centering="legacy",
|
||||
size=self._vgg_face.input_size,
|
||||
is_aligned=True).face
|
||||
self._result.append((filename, self._vgg_face.predict(face)))
|
||||
embedding = self._vgg_face.predict(face)
|
||||
alignments.setdefault("identity", {})["vggface2"] = embedding.tolist()
|
||||
self._iterator.update_png_header(filename, alignments)
|
||||
self._result.append((filename, embedding))
|
||||
|
||||
def sort(self) -> None:
|
||||
""" Sort by dendogram.
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import logging
|
|||
import operator
|
||||
import sys
|
||||
|
||||
from typing import Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
|
@ -22,11 +22,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ImgMetaType = Generator[Tuple[str,
|
||||
Optional[np.ndarray],
|
||||
Optional["PNGHeaderAlignmentsDict"]], None, None]
|
||||
|
||||
|
||||
class SortAlignedMetric(SortMethod): # pylint:disable=too-few-public-methods
|
||||
""" Sort by comparison of metrics stored in an Aligned Face objects. This is a parent class
|
||||
for sort by aligned metrics methods. Individual methods should inherit from this class
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user