From 66ed005ef3d824b4f8221f7fe5019ffca770a94e Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Tue, 24 Sep 2019 12:16:05 +0100 Subject: [PATCH] Optimize Data Augmentation (#881) * Move image utils to lib.image * Add .pylintrc file * Remove some cv2 pylint ignores * TrainingData: Load images from disk in batches * TrainingData: get_landmarks to batch * TrainingData: transform and flip to batches * TrainingData: Optimize color augmentation * TrainingData: Optimize target and random_warp * TrainingData - Convert _get_closest_match for batching * TrainingData: Warp To Landmarks optimized * Save models to threadpoolexecutor * Move stack_images, Rename ImageManipulation. ImageAugmentation Docstrings * Masks: Set dtype and threshold for lib.masks based on input face * Docstrings and Documentation --- .gitignore | 1 + .pylintrc | 570 +++++++++++++++++++ docs/full/lib.image.rst | 7 + docs/full/lib.rst | 2 + docs/full/lib.training_data.rst | 7 + docs/index.rst | 2 +- lib/alignments.py | 2 +- lib/face_filter.py | 6 +- lib/faces_detect.py | 87 +++ lib/image.py | 302 ++++++++++ lib/model/masks.py | 20 +- lib/training_data.py | 970 ++++++++++++++++++++------------ lib/utils.py | 242 +------- plugins/extract/detect/_base.py | 3 +- plugins/train/model/_base.py | 24 +- plugins/train/trainer/_base.py | 75 ++- scripts/convert.py | 7 +- scripts/extract.py | 5 +- scripts/fsmedia.py | 9 +- scripts/train.py | 7 +- tools/lib_alignments/media.py | 10 +- tools/sort.py | 16 +- 22 files changed, 1709 insertions(+), 665 deletions(-) create mode 100644 .pylintrc create mode 100644 docs/full/lib.image.rst create mode 100644 docs/full/lib.training_data.rst create mode 100644 lib/image.py diff --git a/.gitignore b/.gitignore index 09b2c8d..0cbbddf 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ !plugins/extract/* !plugins/train/* !plugins/convert/* +!.pylintrc !tools !tools/lib* !_travis diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 
0000000..6907995 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,570 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist=cv2 + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=print-statement, + parameter-unpacking, + unpacking-in-except, + old-raise-syntax, + backtick, + long-suffix, + old-ne-operator, + old-octal-literal, + import-star-module-level, + non-ascii-bytes-literal, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + apply-builtin, + basestring-builtin, + buffer-builtin, + cmp-builtin, + coerce-builtin, + execfile-builtin, + file-builtin, + long-builtin, + raw_input-builtin, + reduce-builtin, + standarderror-builtin, + unicode-builtin, + xrange-builtin, + coerce-method, + delslice-method, + getslice-method, + setslice-method, + no-absolute-import, + old-division, + dict-iter-method, + dict-view-method, + next-method-called, + metaclass-assignment, + indexing-exception, + raising-string, + reload-builtin, + oct-method, + hex-method, + nonzero-method, + cmp-method, + input-builtin, + round-builtin, + intern-builtin, + unichr-builtin, + map-builtin-not-iterating, + zip-builtin-not-iterating, + range-builtin-not-iterating, + filter-builtin-not-iterating, + using-cmp-argument, + eq-without-hash, + div-method, + idiv-method, + rdiv-method, + exception-message-attribute, + invalid-str-codec, + sys-max-int, + bad-python3-import, + deprecated-string-function, + deprecated-str-translate-call, + deprecated-itertools-function, + 
deprecated-types-field, + next-method-defined, + dict-items-not-iterating, + dict-keys-not-iterating, + dict-values-not-iterating, + deprecated-operator-function, + deprecated-urllib-function, + xreadlines-attribute, + deprecated-sys-function, + exception-escape, + comprehension-escape + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. 
+never-returning-functions=sys.exit + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. 
+method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +#variable-rgx= + + +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, while `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package.. +spelling-dict= + +# List of comma separated words that should not be checked. 
+spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma, + dict-separator + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[STRING] + +# This flag controls whether the implicit-str-concat-in-sequence should +# generate a warning on implicit string concatenation in sequences defined over +# several lines. +check-str-concat-over-line-jumps=no + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. 
+allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. 
+ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement. +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). 
+min-public-methods=2 + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". 
+overgeneral-exceptions=BaseException, + Exception diff --git a/docs/full/lib.image.rst b/docs/full/lib.image.rst new file mode 100644 index 0000000..36ba2e5 --- /dev/null +++ b/docs/full/lib.image.rst @@ -0,0 +1,7 @@ +lib.image module +======================== + +.. automodule:: lib.image + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/full/lib.rst b/docs/full/lib.rst index a4726c5..44ca917 100644 --- a/docs/full/lib.rst +++ b/docs/full/lib.rst @@ -8,6 +8,8 @@ Subpackages lib.model lib.faces_detect + lib.image + lib.training_data Module contents --------------- diff --git a/docs/full/lib.training_data.rst b/docs/full/lib.training_data.rst new file mode 100644 index 0000000..6865234 --- /dev/null +++ b/docs/full/lib.training_data.rst @@ -0,0 +1,7 @@ +lib.training\_data module +========================= + +.. automodule:: lib.training_data + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index 511d36b..ca88bba 100755 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ faceswap.dev Developer Documentation ==================================== .. 
toctree:: - :maxdepth: 4 + :maxdepth: 2 :caption: Contents: full/modules diff --git a/lib/alignments.py b/lib/alignments.py index 8717947..def51d8 100644 --- a/lib/alignments.py +++ b/lib/alignments.py @@ -8,8 +8,8 @@ from datetime import datetime import cv2 +from lib.faces_detect import rotate_landmarks from lib import Serializer -from lib.utils import rotate_landmarks logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/lib/face_filter.py b/lib/face_filter.py index 3191971..36ef194 100644 --- a/lib/face_filter.py +++ b/lib/face_filter.py @@ -4,7 +4,7 @@ import logging from lib.vgg_face import VGGFace -from lib.utils import cv2_read_img +from lib.image import read_image from plugins.extract.pipeline import Extractor logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -47,10 +47,10 @@ class FaceFilter(): """ Load the images """ retval = dict() for fpath in reference_file_paths: - retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True), + retval[fpath] = {"image": read_image(fpath, raise_error=True), "type": "filter"} for fpath in nreference_file_paths: - retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True), + retval[fpath] = {"image": read_image(fpath, raise_error=True), "type": "nfilter"} logger.debug("Loaded filter images: %s", {k: v["type"] for k, v in retval.items()}) return retval diff --git a/lib/faces_detect.py b/lib/faces_detect.py index b712988..6525af6 100644 --- a/lib/faces_detect.py +++ b/lib/faces_detect.py @@ -2,6 +2,7 @@ """ Face and landmarks detection for faceswap.py """ import logging +import cv2 import numpy as np from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling @@ -399,3 +400,89 @@ class DetectedFace(): if not self.reference: return None return get_matrix_scaling(self.reference_matrix) + + +def rotate_landmarks(face, rotation_matrix): + """ Rotates the 68 point landmarks and detection bounding box around the given rotation matrix. 
+ + Paramaters + ---------- + face: DetectedFace or dict + A :class:`DetectedFace` or an `alignments file` ``dict`` containing the 68 point landmarks + and the `x`, `w`, `y`, `h` detection bounding box points. + rotation_matrix: numpy.ndarray + The rotation matrix to rotate the given object by. + + Returns + ------- + DetectedFace or dict + The rotated :class:`DetectedFace` or `alignments file` ``dict`` with the landmarks and + detection bounding box points rotated by the given matrix. The return type is the same as + the input type for ``face`` + """ + logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s", + rotation_matrix, type(face)) + rotated_landmarks = None + # Detected Face Object + if isinstance(face, DetectedFace): + bounding_box = [[face.x, face.y], + [face.x + face.w, face.y], + [face.x + face.w, face.y + face.h], + [face.x, face.y + face.h]] + landmarks = face.landmarks_xy + + # Alignments Dict + elif isinstance(face, dict) and "x" in face: + bounding_box = [[face.get("x", 0), face.get("y", 0)], + [face.get("x", 0) + face.get("w", 0), + face.get("y", 0)], + [face.get("x", 0) + face.get("w", 0), + face.get("y", 0) + face.get("h", 0)], + [face.get("x", 0), + face.get("y", 0) + face.get("h", 0)]] + landmarks = face.get("landmarks_xy", list()) + + else: + raise ValueError("Unsupported face type") + + logger.trace("Original landmarks: %s", landmarks) + + rotation_matrix = cv2.invertAffineTransform( + rotation_matrix) + rotated = list() + for item in (bounding_box, landmarks): + if not item: + continue + points = np.array(item, np.int32) + points = np.expand_dims(points, axis=0) + transformed = cv2.transform(points, + rotation_matrix).astype(np.int32) + rotated.append(transformed.squeeze()) + + # Bounding box should follow x, y planes, so get min/max + # for non-90 degree rotations + pt_x = min([pnt[0] for pnt in rotated[0]]) + pt_y = min([pnt[1] for pnt in rotated[0]]) + pt_x1 = max([pnt[0] for pnt in rotated[0]]) + pt_y1 = max([pnt[1] for 
pnt in rotated[0]]) + width = pt_x1 - pt_x + height = pt_y1 - pt_y + + if isinstance(face, DetectedFace): + face.x = int(pt_x) + face.y = int(pt_y) + face.w = int(width) + face.h = int(height) + face.r = 0 + if len(rotated) > 1: + rotated_landmarks = [tuple(point) for point in rotated[1].tolist()] + face.landmarks_xy = rotated_landmarks + else: + face["left"] = int(pt_x) + face["top"] = int(pt_y) + face["right"] = int(pt_x1) + face["bottom"] = int(pt_y1) + rotated_landmarks = face + + logger.trace("Rotated landmarks: %s", rotated_landmarks) + return face diff --git a/lib/image.py b/lib/image.py new file mode 100644 index 0000000..3b3e99e --- /dev/null +++ b/lib/image.py @@ -0,0 +1,302 @@ +#!/usr/bin python3 +""" Utilities for working with images and videos """ + +import logging +import subprocess +import sys + +from concurrent import futures +from hashlib import sha1 + +import cv2 +import imageio_ffmpeg as im_ffm +import numpy as np + +from lib.utils import convert_to_secs, FaceswapError + +logger = logging.getLogger(__name__) # pylint:disable=invalid-name + +# ################### # +# <<< IMAGE UTILS >>> # +# ################### # + + +# <<< IMAGE IO >>> # + +def read_image(filename, raise_error=False): + """ Read an image file from a file location. + + Extends the functionality of :func:`cv2.imread()` by ensuring that an image was actually + loaded. Errors can be logged and ignored so that the process can continue on an image load + failure. + + Parameters + ---------- + filename: str + Full path to the image to be loaded. + raise_error: bool, optional + If ``True``, then any failures (including the returned image being ``None``) will be + raised. If ``False`` then an error message will be logged, but the error will not be + raised. Default: ``False`` + + Returns + ------- + numpy.ndarray + The image in `BGR` channel order. 
+ + Example + ------- + >>> image_file = "/path/to/image.png" + >>> try: + >>> image = read_image(image_file, raise_error=True) + >>> except: + >>> raise ValueError("There was an error") + """ + logger.trace("Requested image: '%s'", filename) + success = True + image = None + try: + image = cv2.imread(filename) + if image is None: + raise ValueError + except TypeError: + success = False + msg = "Error while reading image (TypeError): '{}'".format(filename) + logger.error(msg) + if raise_error: + raise Exception(msg) + except ValueError: + success = False + msg = ("Error while reading image. This is most likely caused by special characters in " + "the filename: '{}'".format(filename)) + logger.error(msg) + if raise_error: + raise Exception(msg) + except Exception as err: # pylint:disable=broad-except + success = False + msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err)) + logger.error(msg) + if raise_error: + raise Exception(msg) + logger.trace("Loaded image: '%s'. Success: %s", filename, success) + return image + + +def read_image_batch(filenames): + """ Load a batch of images from the given file locations. + + Leverages multi-threading to load multiple images from disk at the same time + leading to vastly reduced image read times. + + Parameters + ---------- + filenames: list + A list of ``str`` full paths to the images to be loaded. + + Returns + ------- + numpy.ndarray + The batch of images in `BGR` channel order. + + Notes + ----- + As the images are compiled into a batch, they must be all of the same dimensions. 
+ + Example + ------- + >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"] + >>> images = read_image_batch(image_filenames) + """ + logger.trace("Requested batch: '%s'", filenames) + executor = futures.ThreadPoolExecutor() + with executor: + images = [executor.submit(read_image, filename, raise_error=True) + for filename in filenames] + batch = np.array([future.result() for future in futures.as_completed(images)]) + logger.trace("Returning images: %s", batch.shape) + return batch + + +def read_image_hash(filename): + """ Return the `sha1` hash of an image saved on disk. + + Parameters + ---------- + filename: str + Full path to the image to be loaded. + + Returns + ------- + str + The :func:`hashlib.hexdigest()` representation of the `sha1` hash of the given image. + Example + ------- + >>> image_file = "/path/to/image.png" + >>> image_hash = read_image_hash(image_file) + """ + img = read_image(filename, raise_error=True) + image_hash = sha1(img).hexdigest() + logger.trace("filename: '%s', hash: %s", filename, image_hash) + return image_hash + + +def encode_image_with_hash(image, extension): + """ Encode an image, and get the encoded image back with its `sha1` hash. + + Parameters + ---------- + image: numpy.ndarray + The image to be encoded in `BGR` channel order. + extension: str + A compatible `cv2` image file extension that the final image is to be saved to. 
+ + Returns + ------- + image_hash: str + The :func:`hashlib.hexdigest()` representation of the `sha1` hash of the encoded image + encoded_image: bytes + The image encoded into the correct file format + + Example + ------- + >>> image_file = "/path/to/image.png" + >>> image = read_image(image_file) + >>> image_hash, encoded_image = encode_image_with_hash(image, ".jpg") + """ + encoded_image = cv2.imencode(extension, image)[1] + image_hash = sha1(cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED)).hexdigest() + return image_hash, encoded_image + + +def batch_convert_color(batch, colorspace): + """ Convert a batch of images from one color space to another. + + Converts a batch of images by reshaping the batch prior to conversion rather than iterating + over the images. This leads to a significant speed up in the convert process. + + Parameters + ---------- + batch: numpy.ndarray + A batch of images. + colorspace: str + The OpenCV Color Conversion Code suffix. For example for BGR to LAB this would be + ``'BGR2LAB'``. + See https://docs.opencv.org/4.1.1/d8/d01/group__imgproc__color__conversions.html for a full + list of color codes. + + Returns + ------- + numpy.ndarray + The batch converted to the requested color space. + + Example + ------- + >>> images_bgr = numpy.array([image1, image2, image3]) + >>> images_lab = batch_convert_color(images_bgr, "BGR2LAB") + + Notes + ----- + This function is only compatible for color space conversions that have the same image shape + for source and destination color spaces. + + If you use :func:`batch_convert_color` with 8-bit images, the conversion will have some + information lost. For many cases, this will not be noticeable but it is recommended + to use 32-bit images in cases that need the full range of colors or that convert an image + before an operation and then convert back. 
+    """
+    logger.trace("Batch converting: (batch shape: %s, colorspace: %s)", batch.shape, colorspace)
+    original_shape = batch.shape
+    batch = batch.reshape((original_shape[0] * original_shape[1], *original_shape[2:]))
+    batch = cv2.cvtColor(batch, getattr(cv2, "COLOR_{}".format(colorspace)))
+    return batch.reshape(original_shape)
+
+
+# ################### #
+# <<< VIDEO UTILS >>> #
+# ################### #
+
+def count_frames_and_secs(filename, timeout=60):
+    """ Count the number of frames and seconds in a video file.
+
+    Adapted from :mod:`imageio_ffmpeg` to handle the issue of ffmpeg occasionally hanging
+    inside a subprocess.
+
+    If the operation times out then the process will try to read the data again, up to a total
+    of 3 times. If the data still cannot be read then an exception will be raised.
+
+    Note that this operation can be quite slow for large files.
+
+    Parameters
+    ----------
+    filename: str
+        Full path to the video to be analyzed.
+    timeout: int, optional
+        The amount of time in seconds to wait for the video data before aborting.
+        Default: ``60``
+
+    Returns
+    -------
+    nframes: int
+        The number of frames in the given video file.
+    nsecs: float
+        The duration, in seconds, of the given video file.
+
+    Example
+    -------
+    >>> video = "/path/to/video.mp4"
+    >>> frames, secs = count_frames_and_secs(video)
+    """
+    # https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg
+
+    assert isinstance(filename, str), "Video path must be a string"
+    exe = im_ffm.get_ffmpeg_exe()
+    iswin = sys.platform.startswith("win")
+    logger.debug("iswin: '%s'", iswin)
+    cmd = [exe, "-i", filename, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"]
+    logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
+    attempts = 3
+    for attempt in range(attempts):
+        try:
+            logger.debug("attempt: %s of %s", attempt + 1, attempts)
+            out = subprocess.check_output(cmd,
+                                          stderr=subprocess.STDOUT,
+                                          shell=iswin,
+                                          timeout=timeout)
+            logger.debug("Successfully communicated with FFMPEG")
+            break
+        except subprocess.CalledProcessError as err:
+            out = err.output.decode(errors="ignore")
+            raise RuntimeError("FFMPEG call failed with {}:\n{}".format(err.returncode, out))
+        except subprocess.TimeoutExpired as err:
+            this_attempt = attempt + 1
+            if this_attempt == attempts:
+                msg = ("FFMPEG hung while attempting to obtain the frame count. "
+                       "Sometimes this issue resolves itself, so you can try running again. "
+                       "Otherwise use the Effmpeg Tool to extract the frames from your video into "
+                       "a folder, and then run the requested Faceswap process on that folder.")
+                raise FaceswapError(msg) from err
+            logger.warning("FFMPEG hung while attempting to obtain the frame count. "
+                           "Retrying %s of %s", this_attempt + 1, attempts)
+            continue
+
+    # Note that other than with the subprocess calls below, ffmpeg won't hang here.
+    # Worst case Python will stop/crash and ffmpeg will continue running until done.
+ + nframes = nsecs = None + for line in reversed(out.splitlines()): + if not line.startswith(b"frame="): + continue + line = line.decode(errors="ignore") + logger.debug("frame line: '%s'", line) + idx = line.find("frame=") + if idx >= 0: + splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip() + nframes = int(splitframes) + idx = line.find("time=") + if idx >= 0: + splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip() + nsecs = convert_to_secs(*splittime.split(":")) + logger.debug("nframes: %s, nsecs: %s", nframes, nsecs) + return nframes, nsecs + + raise RuntimeError("Could not get number of frames") # pragma: no cover diff --git a/lib/model/masks.py b/lib/model/masks.py index cb41bf7..d7c0d68 100644 --- a/lib/model/masks.py +++ b/lib/model/masks.py @@ -43,6 +43,8 @@ class Mask(): self.__class__.__name__, face.shape, channels, landmarks) self.landmarks = landmarks self.face = face + self.dtype = face.dtype + self.threshold = 255 if self.dtype == "uint8" else 255.0 self.channels = channels mask = self.build_mask() @@ -73,7 +75,7 @@ class Mask(): class dfl_full(Mask): # pylint: disable=invalid-name """ DFL facial mask """ def build_mask(self): - mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype) nose_ridge = (self.landmarks[27:31], self.landmarks[33:34]) jaw = (self.landmarks[0:17], @@ -90,14 +92,14 @@ class dfl_full(Mask): # pylint: disable=invalid-name for item in parts: merged = np.concatenate(item) - cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) 
# pylint: disable=no-member + cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold) return mask class components(Mask): # pylint: disable=invalid-name """ Component model mask """ def build_mask(self): - mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype) r_jaw = (self.landmarks[0:9], self.landmarks[17:18]) l_jaw = (self.landmarks[8:17], self.landmarks[26:27]) @@ -117,7 +119,7 @@ class components(Mask): # pylint: disable=invalid-name for item in parts: merged = np.concatenate(item) - cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold) return mask @@ -126,7 +128,7 @@ class extended(Mask): # pylint: disable=invalid-name Based on components mask. Attempts to extend the eyebrow points up the forehead """ def build_mask(self): - mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype) landmarks = self.landmarks.copy() # mid points between the side of face and eye point @@ -161,15 +163,15 @@ class extended(Mask): # pylint: disable=invalid-name for item in parts: merged = np.concatenate(item) - cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) 
# pylint: disable=no-member + cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold) return mask class facehull(Mask): # pylint: disable=invalid-name """ Basic face hull mask """ def build_mask(self): - mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) - hull = cv2.convexHull( # pylint: disable=no-member + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype) + hull = cv2.convexHull( np.array(self.landmarks).reshape((-1, 2))) - cv2.fillConvexPoly(mask, hull, 255.0, lineType=cv2.LINE_AA) # pylint: disable=no-member + cv2.fillConvexPoly(mask, hull, self.threshold, lineType=cv2.LINE_AA) return mask diff --git a/lib/training_data.py b/lib/training_data.py index 423f744..2c3c67a 100644 --- a/lib/training_data.py +++ b/lib/training_data.py @@ -1,88 +1,191 @@ #!/usr/bin/env python3 -""" Process training data for model training """ +""" Handles Data Augmentation for feeding Faceswap Models """ import logging from hashlib import sha1 -from random import random, shuffle, choice +from random import shuffle, choice import numpy as np import cv2 from scipy.interpolate import griddata +from lib.image import batch_convert_color, read_image_batch from lib.model import masks from lib.multithreading import BackgroundGenerator -from lib.queue_manager import queue_manager -from lib.umeyama import umeyama -from lib.utils import cv2_read_img, FaceswapError +from lib.utils import FaceswapError logger = logging.getLogger(__name__) # pylint: disable=invalid-name class TrainingDataGenerator(): - """ Generate training data for models """ + """ A Training Data Generator for compiling data for feeding to a model. + + This class is called from :mod:`plugins.train.trainer._base` and launches a background + iterator that compiles augmented data, target data and sample data. + + Parameters + ---------- + model_input_size: int + The expected input size for the model. It is assumed that the input to the model is always + a square image. 
This is the size, in pixels, of the `width` and the `height` of the input + to the model. + model_output_shapes: list + A list of tuples defining the output shapes from the model, in the order that the outputs + are returned. The tuples should be in (`height`, `width`, `channels`) format. + training_opts: dict + This is a dictionary of model training options as defined in + :mod:`plugins.train.model._base`. These options will be defined by the user from the + provided cli options or from the model ``config.ini``. At a minimum this ``dict`` should + contain the following keys: + + * **coverage_ratio** (`float`) - The ratio of the training image to be trained on. \ + Dictates how much of the image will be cropped out. Eg: a coverage ratio of 0.625 \ + will result in cropping a 160px box from a 256px image (256 * 0.625 = 160). + + * **augment_color** (`bool`) - ``True`` if color is to be augmented, otherwise ``False`` \ + + * **no_flip** (`bool`) - ``True`` if the image shouldn't be randomly flipped as part of \ + augmentation, otherwise ``False`` + + * **mask_type** (`str`) - The mask type to be used (as defined in \ + :mod:`lib.model.masks`). If not ``None`` then the additional key ``landmarks`` must be \ + provided. + + * **warp_to_landmarks** (`bool`) - ``True`` if the random warp method should warp to \ + similar landmarks from the other side, ``False`` if the standard random warp method \ + should be used. If ``True`` then the additional key ``landmarks`` must be provided. + + * **landmarks** (`numpy.ndarray`, `optional`). Required if using a :attr:`mask_type` is \ + not ``None`` or :attr:`warp_to_landmarks` is ``True``. The 68 point face landmarks from \ + an alignments file. + + config: dict + The configuration ``dict`` generated from :file:`config.train.ini` containing the trainer \ + plugin configuration options. 
+ """ def __init__(self, model_input_size, model_output_shapes, training_opts, config): logger.debug("Initializing %s: (model_input_size: %s, model_output_shapes: %s, " "training_opts: %s, landmarks: %s, config: %s)", self.__class__.__name__, model_input_size, model_output_shapes, {key: val for key, val in training_opts.items() if key != "landmarks"}, bool(training_opts.get("landmarks", None)), config) - self.batchsize = 0 - self.model_input_size = model_input_size - self.model_output_shapes = model_output_shapes - self.training_opts = training_opts - self.mask_class = self.set_mask_class() - self.landmarks = self.training_opts.get("landmarks", None) + self._config = config + self._model_input_size = model_input_size + self._model_output_shapes = model_output_shapes + self._training_opts = training_opts + self._mask_class = self._set__mask_class() + self._landmarks = self._training_opts.get("landmarks", None) self._nearest_landmarks = {} - self.processing = ImageManipulation(model_input_size, - model_output_shapes, - training_opts.get("coverage_ratio", 0.625), - config) - logger.debug("Initialized %s", self.__class__.__name__) - def set_mask_class(self): - """ Set the mask function to use if using mask """ - mask_type = self.training_opts.get("mask_type", None) - if mask_type: - logger.debug("Mask type: '%s'", mask_type) - mask_class = getattr(masks, mask_type) - else: - mask_class = None - logger.debug("Mask class: %s", mask_class) - return mask_class + # Batchsize and processing class are set when this class is called by a batcher + # from lib.training_data + self._batchsize = 0 + self._processing = None + logger.debug("Initialized %s", self.__class__.__name__) def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_preview=False, is_timelapse=False): - """ Keep a queue filled to 8x Batch Size """ + """ A Background iterator to return augmented images, samples and targets. 
+
+        The exit point from this class and the sole attribute that should be referenced. Called
+        from :mod:`plugins.train.trainer._base`. Returns an iterator that yields images for
+        training, preview and timelapses.
+
+        Parameters
+        ----------
+        images: list
+            A list of image paths that will be used to compile the final augmented data from.
+        batchsize: int
+            The batchsize for this iterator. Images will be returned in ``numpy.ndarray`` s of
+            this size from the iterator.
+        side: {'a' or 'b'}
+            The side of the model that this iterator is for.
+        do_shuffle: bool, optional
+            Whether data should be shuffled prior to loading from disk. If true, each time the full
+            list of filenames are processed, the data will be reshuffled to make sure they are not
+            returned in the same order. Default: ``True``
+        is_preview: bool, optional
+            Indicates whether this iterator is generating preview images. If ``True`` then certain
+            augmentations will not be performed. Default: ``False``
+        is_timelapse: bool, optional
+            Indicates whether this iterator is generating Timelapse images. If ``True``, then
+            certain augmentations will not be performed. Default: ``False``
+
+        Yields
+        ------
+        dict
+            The following items are contained in each ``dict`` yielded from this iterator:
+
+            * **feed** (`numpy.ndarray`) - The feed for the model. The array returned is in the \
+            format (`batchsize`, `height`, `width`, `channels`). This is the :attr:`x` parameter \
+            for :func:`keras.models.model.train_on_batch`.
+
+            * **targets** (`list`) - A list of 4-dimensional ``numpy.ndarray`` s in the order \
+            and size of each output of the model as defined in :attr:`model_output_shapes`. The \
+            format of these arrays will be (`batchsize`, `height`, `width`, `3`). This is \
+            the :attr:`y` parameter for :func:`keras.models.model.train_on_batch` **NB:** \
+            masks are not included in the ``targets`` list. If required for feeding into the \
+            Keras model, they will need to be added to this list in \
+            :mod:`plugins.train.trainer._base` from the ``masks`` key.
+
+            * **masks** (`numpy.ndarray`) - A 4-dimensional array containing the target masks in \
+            the format (`batchsize`, `height`, `width`, `1`). **NB:** This item will only exist \
+            in the ``dict`` if the :attr:`mask_type` is not ``None``
+
+            * **samples** (`numpy.ndarray`) - A 4-dimensional array containing the samples for \
+            feeding to the model's predict function for generating preview and timelapse samples. \
+            The array will be in the format (`batchsize`, `height`, `width`, `channels`). **NB:** \
+            This item will only exist in the ``dict`` if :attr:`is_preview` or \
+            :attr:`is_timelapse` is ``True``
+        """
         logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
                      "is_preview, %s, is_timelapse: %s)", len(images), batchsize, side, do_shuffle,
                      is_preview, is_timelapse)
-        self.batchsize = batchsize
-        is_display = is_preview or is_timelapse
-        args = (images, side, is_display, do_shuffle, batchsize)
-        batcher = BackgroundGenerator(self.minibatch, thread_count=2, args=args)
+        self._batchsize = batchsize
+        self._processing = ImageAugmentation(batchsize,
+                                             is_preview or is_timelapse,
+                                             self._model_input_size,
+                                             self._model_output_shapes,
+                                             self._training_opts.get("coverage_ratio", 0.625),
+                                             self._config)
+        args = (images, side, do_shuffle, batchsize)
+        batcher = BackgroundGenerator(self._minibatch, thread_count=2, args=args)
        return batcher.iterator()

-    def validate_samples(self, data):
-        """ Check the total number of images against batchsize and return
-            the total number of images """
+    # << INTERNAL METHODS >> #
+    def _set__mask_class(self):
+        """ Returns the correct mask class from :mod:`lib.model.masks` as defined in the
+        :attr:`mask_type` parameter. 
""" + mask_type = self._training_opts.get("mask_type", None) + if mask_type: + logger.debug("Mask type: '%s'", mask_type) + _mask_class = getattr(masks, mask_type) + else: + _mask_class = None + logger.debug("Mask class: %s", _mask_class) + return _mask_class + + def _validate_samples(self, data): + """ Ensures that the total number of images within :attr:`images` is greater or equal to + the selected :attr:`batchsize`. Raises an exception if this is not the case. """ length = len(data) msg = ("Number of images is lower than batch-size (Note that too few " "images may lead to bad training). # images: {}, " - "batch-size: {}".format(length, self.batchsize)) + "batch-size: {}".format(length, self._batchsize)) try: - assert length >= self.batchsize, msg + assert length >= self._batchsize, msg except AssertionError as err: msg += ("\nYou should increase the number of images in your training set or lower " "your batch-size.") raise FaceswapError(msg) from err - def minibatch(self, images, side, is_display, do_shuffle, batchsize): - """ A generator function that yields epoch, batchsize of warped_img - and batchsize of target_img from the load queue """ - logger.debug("Loading minibatch generator: (image_count: %s, side: '%s', is_display: %s, " - "do_shuffle: %s)", len(images), side, is_display, do_shuffle) - self.validate_samples(images) + def _minibatch(self, images, side, do_shuffle, batchsize): + """ A generator function that yields the augmented, target and sample images. + see :func:`minibatch_ab` for more details on the output. 
""" + logger.debug("Loading minibatch generator: (image_count: %s, side: '%s', do_shuffle: %s)", + len(images), side, do_shuffle) + self._validate_samples(images) def _img_iter(imgs): while True: @@ -93,369 +196,530 @@ class TrainingDataGenerator(): img_iter = _img_iter(images) while True: - batch = list() - for _ in range(batchsize): - img_path = next(img_iter) - data = self.process_face(img_path, side, is_display) - batch.append(data) - batch = list(zip(*batch)) - batch = [np.array(x, dtype="float32") for x in batch] - logger.trace("Yielding batch: (size: %s, item shapes: %s, side: '%s', " - "is_display: %s)", - len(batch), [item.shape for item in batch], side, is_display) - yield batch + img_paths = [next(img_iter) for _ in range(batchsize)] + yield self._process_batch(img_paths, side) - logger.debug("Finished minibatch generator: (side: '%s', is_display: %s)", - side, is_display) + logger.debug("Finished minibatch generator: (side: '%s')", side) - def process_face(self, filename, side, is_display): - """ Load an image and perform transformation and warping """ - logger.trace("Process face: (filename: '%s', side: '%s', is_display: %s)", - filename, side, is_display) - image = cv2_read_img(filename, raise_error=True) - if self.mask_class or self.training_opts["warp_to_landmarks"]: - src_pts = self.get_landmarks(filename, image, side) - if self.mask_class: - image = self.mask_class(src_pts, image, channels=4).mask + def _process_batch(self, filenames, side): + """ Performs the augmentation and compiles target images and samples. See + :func:`minibatch_ab` for more details on the output. 
""" + logger.trace("Process batch: (filenames: '%s', side: '%s')", filenames, side) + batch = read_image_batch(filenames) + processed = dict() + to_landmarks = self._training_opts["warp_to_landmarks"] - image = self.processing.color_adjust(image, - self.training_opts["augment_color"], - is_display) - if not is_display: - image = self.processing.random_transform(image) - if not self.training_opts["no_flip"]: - image = self.processing.do_random_flip(image) - sample = image.copy()[:, :, :3] + # Initialize processing training size on first image + if not self._processing.initialized: + self._processing.initialize(batch.shape[1]) - if self.training_opts["warp_to_landmarks"]: - dst_pts = self.get_closest_match(filename, side, src_pts) - processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts) + # Get Landmarks prior to manipulating the image + if self._mask_class or to_landmarks: + batch_src_pts = self._get_landmarks(filenames, batch, side) + + # Color augmentation before mask is added + if self._training_opts["augment_color"]: + batch = self._processing.color_adjust(batch) + + # Add mask to batch prior to transforms and warps + if self._mask_class: + batch = np.array([self._mask_class(src_pts, image, channels=4).mask + for src_pts, image in zip(batch_src_pts, batch)]) + + # Random Transform and flip + batch = self._processing.transform(batch) + if not self._training_opts["no_flip"]: + batch = self._processing.random_flip(batch) + + # Add samples to output if this is for display + if self._processing.is_display: + processed["samples"] = batch[..., :3].astype("float32") / 255.0 + + # Get Targets + processed.update(self._processing.get_targets(batch)) + + # Random Warp + if to_landmarks: + warp_kwargs = dict(batch_src_points=batch_src_pts, + batch_dst_points=self._get_closest_match(filenames, + side, + batch_src_pts)) else: - processed = self.processing.random_warp(image) + warp_kwargs = dict() + processed["feed"] = self._processing.warp(batch[..., :3], 
to_landmarks, **warp_kwargs) + + logger.trace("Processed batch: (filenames: %s, side: '%s', processed: %s)", + filenames, + side, + {k: v.shape if isinstance(v, np.ndarray) else[i.shape for i in v] + for k, v in processed.items()}) - processed.insert(0, sample) - logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)", - filename, side, [img.shape for img in processed]) return processed - def get_landmarks(self, filename, image, side): - """ Return the landmarks for this face """ - logger.trace("Retrieving landmarks: (filename: '%s', side: '%s'", filename, side) - lm_key = sha1(image).hexdigest() - try: - src_points = self.landmarks[side][lm_key] - except KeyError as err: - msg = ("At least one of your images does not have a matching entry in your alignments " - "file." + def _get_landmarks(self, filenames, batch, side): + """ Obtains the 68 Point Landmarks for the images in this batch. This is only called if + config item ``warp_to_landmarks`` is ``True`` or if :attr:`mask_type` is not ``None``. If + the landmarks for an image cannot be found, then an error is raised. """ + logger.trace("Retrieving landmarks: (filenames: '%s', side: '%s'", filenames, side) + src_points = [self._landmarks[side].get(sha1(face).hexdigest(), None) for face in batch] + + # Raise error on missing alignments + if not all(isinstance(pts, np.ndarray) for pts in src_points): + indices = [idx for idx, hsh in enumerate(src_points) if hsh is None] + missing = [filenames[idx] for idx in indices] + msg = ("Files missing alignments for this batch: {}" + "\nAt least one of your images does not have a matching entry in your " + "alignments file." "\nIf you are training with a mask or using 'warp to landmarks' then every " "face you intend to train on must exist within the alignments file." - "\nThe specific file that caused the failure was '{}' which has a hash of {}." 
- "\nMost likely there will be more than just this file missing from the " + "\nThe specific files that caused this failure are listed above." + "\nMost likely there will be more than just these files missing from the " "alignments file. You can use the Alignments Tool to help identify missing " - "alignments".format(lm_key, filename)) - raise FaceswapError(msg) from err + "alignments".format(missing)) + raise FaceswapError(msg) + logger.trace("Returning: (src_points: %s)", src_points) - return src_points + return np.array(src_points) - def get_closest_match(self, filename, side, src_points): - """ Return closest matched landmarks from opposite set """ - logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s'", - filename, src_points) - landmarks = self.landmarks["a"] if side == "b" else self.landmarks["b"] - closest_hashes = self._nearest_landmarks.get(filename) - if not closest_hashes: - dst_points_items = list(landmarks.items()) - dst_points = list(x[1] for x in dst_points_items) + def _get_closest_match(self, filenames, side, batch_src_points): + """ Only called if the config item ``warp_to_landmarks`` is ``True``. Gets the closest + matched 68 point landmarks from the opposite training set. 
""" + logger.trace("Retrieving closest matched landmarks: (filenames: '%s', src_points: '%s'", + filenames, batch_src_points) + landmarks = self._landmarks["a"] if side == "b" else self._landmarks["b"] + closest_hashes = [self._nearest_landmarks.get(filename) for filename in filenames] + if None in closest_hashes: + closest_hashes = self._cache_closest_hashes(filenames, batch_src_points, landmarks) + + batch_dst_points = np.array([landmarks[choice(hsh)] for hsh in closest_hashes]) + logger.trace("Returning: (batch_dst_points: %s)", batch_dst_points.shape) + return batch_dst_points + + def _cache_closest_hashes(self, filenames, batch_src_points, landmarks): + """ Cache the nearest landmarks for this batch """ + logger.trace("Caching closest hashes") + dst_landmarks = list(landmarks.items()) + dst_points = np.array([lm[1] for lm in dst_landmarks]) + batch_closest_hashes = list() + + for filename, src_points in zip(filenames, batch_src_points): closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10] - closest_hashes = tuple(dst_points_items[i][0] for i in closest) + closest_hashes = tuple(dst_landmarks[i][0] for i in closest) self._nearest_landmarks[filename] = closest_hashes - dst_points = landmarks[choice(closest_hashes)] - logger.trace("Returning: (dst_points: %s)", dst_points) - return dst_points + batch_closest_hashes.append(closest_hashes) + logger.trace("Cached closest hashes") + return batch_closest_hashes -class ImageManipulation(): - """ Manipulations to be performed on training images """ - def __init__(self, input_size, output_shapes, coverage_ratio, config): - """ input_size: Size of the face input into the model - output_shapes: Shapes that come out of the model - coverage_ratio: Coverage ratio of full image. 
Eg: 256 * 0.625 = 160 - """ - logger.debug("Initializing %s: (input_size: %s, output_shapes: %s, coverage_ratio: %s, " - "config: %s)", self.__class__.__name__, input_size, output_shapes, +class ImageAugmentation(): + """ Performs augmentation on batches of training images. + + Parameters + ---------- + batchsize: int + The number of images that will be fed through the augmentation functions at once. + is_display: bool + Whether the images being fed through will be used for Preview or Timelapse. Disables + the "warp" augmentation for these images. + input_size: int + The expected input size for the model. It is assumed that the input to the model is always + a square image. This is the size, in pixels, of the `width` and the `height` of the input + to the model. + output_shapes: list + A list of tuples defining the output shapes from the model, in the order that the outputs + are returned. The tuples should be in (`height`, `width`, `channels`) format. + coverage_ratio: float + The ratio of the training image to be trained on. Dictates how much of the image will be + cropped out. Eg: a coverage ratio of 0.625 will result in cropping a 160px box from a 256px + image (256 * 0.625 = 160). + config: dict + The configuration ``dict`` generated from :file:`config.train.ini` containing the trainer \ + plugin configuration options. 
+ + Attributes + ---------- + initialized: bool + Flag to indicate whether :class:`ImageAugmentation` has been initialized with the training + image size in order to cache certain augmentation operations (see :func:`initialize`) + is_display: bool + Flag to indicate whether these augmentations are for timelapses/preview images (``True``) + or standard training data (``False)`` + """ + def __init__(self, batchsize, is_display, input_size, output_shapes, coverage_ratio, config): + logger.debug("Initializing %s: (batchsize: %s, is_display: %s, input_size: %s, " + "output_shapes: %s, coverage_ratio: %s, config: %s)", + self.__class__.__name__, batchsize, is_display, input_size, output_shapes, coverage_ratio, config) - self.config = config + + self.initialized = False + self.is_display = is_display + + # Set on first image load from initialize + self._training_size = 0 + self._constants = None + + self._batchsize = batchsize + self._config = config # Transform and Warp args - self.input_size = input_size - self.output_sizes = [shape[1] for shape in output_shapes if shape[2] == 3] - logger.debug("Output sizes: %s", self.output_sizes) + self._input_size = input_size + self._output_sizes = [shape[1] for shape in output_shapes if shape[2] == 3] + logger.debug("Output sizes: %s", self._output_sizes) # Warp args - self.coverage_ratio = coverage_ratio # Coverage ratio of full image. 
Eg: 256 * 0.625 = 160 - self.scale = 5 # Normal random variable scale + self._coverage_ratio = coverage_ratio + self._scale = 5 # Normal random variable scale + logger.debug("Initialized %s", self.__class__.__name__) - def color_adjust(self, img, augment_color, is_display): - """ Color adjust RGB image """ - logger.trace("Color adjusting image") - if not is_display and augment_color: - logger.trace("Augmenting color") - face, _ = self.separate_mask(img) - face = face.astype("uint8") - face = self.random_clahe(face) - face = self.random_lab(face) - img[:, :, :3] = face - return img.astype('float32') / 255.0 + def initialize(self, training_size): + """ Initializes the caching of constants for use in various image augmentations. - def random_clahe(self, image): - """ Randomly perform Contrast Limited Adaptive Histogram Equilization """ - contrast_random = random() - if contrast_random > self.config.get("color_clahe_chance", 50) / 100: - return image + The training image size is not known prior to loading the images from disk and commencing + training, so it cannot be set in the ``__init__`` method. When the first training batch is + loaded this function should be called to initialize the class and perform various + calculations based on this input size to cache certain constants for image augmentation + calculations. - base_contrast = image.shape[0] // 128 - grid_base = random() * self.config.get("color_clahe_max_size", 4) - contrast_adjustment = int(grid_base * (base_contrast / 2)) - grid_size = base_contrast + contrast_adjustment - logger.trace("Adjusting Contrast. Grid Size: %s", grid_size) + Parameters + ---------- + training_size: int + The size of the training images stored on disk that are to be fed into + :class:`ImageAugmentation`. The training images should always be square and of the + same size. This is the size, in pixels, of the `width` and the `height` of the + training images. + """ + logger.debug("Initializing constants. 
training_size: %s", training_size) + self._training_size = training_size + coverage = int(self._training_size * self._coverage_ratio) - clahe = cv2.createCLAHE(clipLimit=2.0, # pylint: disable=no-member - tileGridSize=(grid_size, grid_size)) - for chan in range(3): - image[:, :, chan] = clahe.apply(image[:, :, chan]) - return image + # Color Aug + clahe_base_contrast = training_size // 128 + # Target Images + tgt_slices = slice(self._training_size // 2 - coverage // 2, + self._training_size // 2 + coverage // 2) - def random_lab(self, image): - """ Perform random color/lightness adjustment in L*a*b* colorspace """ - amount_l = self.config.get("color_lightness", 30) / 100 - amount_ab = self.config.get("color_ab", 8) / 100 + # Random Warp + warp_range_ = np.linspace(self._training_size // 2 - coverage // 2, + self._training_size // 2 + coverage // 2, 5, dtype='float32') + warp_mapx = np.broadcast_to(warp_range_, (self._batchsize, 5, 5)).astype("float32") + warp_mapy = np.broadcast_to(warp_mapx[0].T, (self._batchsize, 5, 5)).astype("float32") - randoms = [(random() * amount_l * 2) - amount_l, # L adjust - (random() * amount_ab * 2) - amount_ab, # A adjust - (random() * amount_ab * 2) - amount_ab] # B adjust + warp_pad = int(1.25 * self._input_size) + warp_slices = slice(warp_pad // 10, -warp_pad // 10) - logger.trace("Random LAB adjustments: %s", randoms) - image = cv2.cvtColor( # pylint:disable=no-member - image, cv2.COLOR_BGR2LAB).astype("float32") / 255.0 # pylint:disable=no-member + # Random Warp Landmarks + p_mx = self._training_size - 1 + p_hf = (self._training_size // 2) - 1 + edge_anchors = np.array([(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0), + (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]).astype("int32") + edge_anchors = np.broadcast_to(edge_anchors, (self._batchsize, 8, 2)) + grids = np.mgrid[0:p_mx:complex(self._training_size), 0:p_mx:complex(self._training_size)] - for idx, adjustment in enumerate(randoms): - if adjustment >= 0: - image[:, :, idx] = 
((1 - image[:, :, idx]) * adjustment) + image[:, :, idx] - else: - image[:, :, idx] = image[:, :, idx] * (1 + adjustment) - image = cv2.cvtColor((image * 255.0).astype("uint8"), # pylint:disable=no-member - cv2.COLOR_LAB2BGR) # pylint:disable=no-member - return image + self._constants = dict(clahe_base_contrast=clahe_base_contrast, + tgt_slices=tgt_slices, + warp_mapx=warp_mapx, + warp_mapy=warp_mapy, + warp_pad=warp_pad, + warp_slices=warp_slices, + warp_lm_edge_anchors=edge_anchors, + warp_lm_grids=grids) + self.initialized = True + logger.debug("Initialized constants: %s", self._constants) + + # <<< TARGET IMAGES >>> # + def get_targets(self, batch): + """ Returns the target images, and masks, if required. + + Parameters + ---------- + batch: numpy.ndarray + This should be a 4-dimensional array of training images in the format (`batchsize`, + `height`, `width`, `channels`). Targets should be requested after performing image + transformations but prior to performing warps. + + Returns + ------- + dict + The following keys will be within the returned dictionary: + + * **targets** (`list`) - A list of 4-dimensional ``numpy.ndarray`` s in the order \ + and size of each output of the model as defined in :attr:`output_shapes`. The \ + format of these arrays will be (`batchsize`, `height`, `width`, `3`). **NB:** \ + masks are not included in the ``targets`` list. If masks are to be included in the \ + output they will be returned as their own item from the ``masks`` key. + + * **masks** (`numpy.ndarray`) - A 4-dimensional array containing the target masks in \ + the format (`batchsize`, `height`, `width`, `1`). 
**NB:** This item will only exist \ + in the ``dict`` if a batch of 4 channel images has been passed in :attr:`batch` + """ + logger.trace("Compiling targets") + slices = self._constants["tgt_slices"] + target_batch = [np.array([cv2.resize(image[slices, slices, :], + (size, size), + cv2.INTER_AREA) + for image in batch]) + for size in self._output_sizes] + logger.trace("Target image shapes: %s", + [tgt.shape for tgt_images in target_batch for tgt in tgt_images]) + + retval = self._separate_target_mask(target_batch) + logger.trace("Final targets: %s", + {k: v.shape if isinstance(v, np.ndarray) else [img.shape for img in v] + for k, v in retval.items()}) + return retval @staticmethod - def separate_mask(image): - """ Return the image and the mask from a 4 channel image """ - mask = None - if image.shape[2] == 4: - logger.trace("Image contains mask") - mask = np.expand_dims(image[:, :, -1], axis=2) - image = image[:, :, :3] + def _separate_target_mask(batch): + """ Return the batch and the batch of final masks + + Returns the targets as a list of 4-dimensional ``numpy.ndarray`` s of shape (`batchsize`, + `height`, `width`, 3). If the :attr:`batch` is 4 channels, then the masks will be split + from the batch, with the largest output masks being returned in their own item. 
+ """ + batch = [tgt.astype("float32") / 255.0 for tgt in batch] + if all(tgt.shape[-1] == 4 for tgt in batch): + logger.trace("Batch contains mask") + sizes = [item.shape[1] for item in batch] + mask_batch = np.expand_dims(batch[sizes.index(max(sizes))][..., -1], axis=-1) + batch = [item[..., :3] for item in batch] + logger.trace("batch shapes: %s, mask_batch shape: %s", + [tgt.shape for tgt in batch], mask_batch.shape) + retval = dict(targets=batch, masks=mask_batch) else: - logger.trace("Image has no mask") - return image, mask + logger.trace("Batch has no mask") + retval = dict(targets=batch) + return retval - def get_coverage(self, image): - """ Return coverage value for given image """ - coverage = int(image.shape[0] * self.coverage_ratio) - logger.trace("Coverage: %s", coverage) - return coverage + # <<< COLOR AUGMENTATION >>> # + def color_adjust(self, batch): + """ Perform color augmentation on the passed in batch. - def random_transform(self, image): - """ Randomly transform an image """ + The color adjustment parameters are set in :file:`config.train.ini` + + Parameters + ---------- + batch: numpy.ndarray + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format. + + Returns + ---------- + numpy.ndarray + A 4-dimensional array of the same shape as :attr:`batch` with color augmentation + applied. 
+ """ + if not self.is_display: + logger.trace("Augmenting color") + batch = batch_convert_color(batch, "BGR2LAB") + batch = self._random_clahe(batch) + batch = self._random_lab(batch) + batch = batch_convert_color(batch, "LAB2BGR") + return batch + + def _random_clahe(self, batch): + """ Randomly perform Contrast Limited Adaptive Histogram Equilization on + a batch of images """ + base_contrast = self._constants["clahe_base_contrast"] + + batch_random = np.random.rand(self._batchsize) + indices = np.where(batch_random > self._config.get("color_clahe_chance", 50) / 100)[0] + + grid_bases = np.rint(np.random.uniform(0, + self._config.get("color_clahe_max_size", 4), + size=indices.shape[0])).astype("uint8") + contrast_adjustment = (grid_bases * (base_contrast // 2)) + grid_sizes = contrast_adjustment + base_contrast + logger.trace("Adjusting Contrast. Grid Sizes: %s", grid_sizes) + + clahes = [cv2.createCLAHE(clipLimit=2.0, # pylint: disable=no-member + tileGridSize=(grid_size, grid_size)) + for grid_size in grid_sizes] + + for idx, clahe in zip(indices, clahes): + batch[idx, :, :, 0] = clahe.apply(batch[idx, :, :, 0]) + return batch + + def _random_lab(self, batch): + """ Perform random color/lightness adjustment in L*a*b* colorspace on a batch of images """ + amount_l = self._config.get("color_lightness", 30) / 100 + amount_ab = self._config.get("color_ab", 8) / 100 + adjust = np.array([amount_l, amount_ab, amount_ab], dtype="float32") + randoms = ( + (np.random.rand(self._batchsize, 1, 1, 3).astype("float32") * (adjust * 2)) - adjust) + logger.trace("Random LAB adjustments: %s", randoms) + + for image, rand in zip(batch, randoms): + for idx in range(rand.shape[-1]): + adjustment = rand[:, :, idx] + if adjustment >= 0: + image[:, :, idx] = ((255 - image[:, :, idx]) * adjustment) + image[:, :, idx] + else: + image[:, :, idx] = image[:, :, idx] * (1 + adjustment) + return batch + + # <<< IMAGE AUGMENTATION >>> # + def transform(self, batch): + """ Perform random 
transformation on the passed in batch. + + The transformation parameters are set in :file:`config.train.ini` + + Parameters + ---------- + batch: numpy.ndarray + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `channels`) and in `BGR` format. + + Returns + ---------- + numpy.ndarray + A 4-dimensional array of the same shape as :attr:`batch` with transformation applied. + """ + if self.is_display: + return batch logger.trace("Randomly transforming image") - height, width = image.shape[0:2] + rotation_range = self._config.get("rotation_range", 10) + zoom_range = self._config.get("zoom_range", 5) / 100 + shift_range = self._config.get("shift_range", 5) / 100 - rotation_range = self.config.get("rotation_range", 10) - rotation = np.random.uniform(-rotation_range, rotation_range) + rotation = np.random.uniform(-rotation_range, + rotation_range, + size=self._batchsize).astype("float32") + scale = np.random.uniform(1 - zoom_range, + 1 + zoom_range, + size=self._batchsize).astype("float32") + tform = np.random.uniform( + -shift_range, + shift_range, + size=(self._batchsize, 2)).astype("float32") * self._training_size - zoom_range = self.config.get("zoom_range", 5) / 100 - scale = np.random.uniform(1 - zoom_range, 1 + zoom_range) + mats = np.array( + [cv2.getRotationMatrix2D((self._training_size // 2, self._training_size // 2), + rot, + scl) + for rot, scl in zip(rotation, scale)]).astype("float32") + mats[..., 2] += tform - shift_range = self.config.get("shift_range", 5) / 100 - tnx = np.random.uniform(-shift_range, shift_range) * width - tny = np.random.uniform(-shift_range, shift_range) * height - - mat = cv2.getRotationMatrix2D( # pylint:disable=no-member - (width // 2, height // 2), rotation, scale) - mat[:, 2] += (tnx, tny) - result = cv2.warpAffine( # pylint:disable=no-member - image, mat, (width, height), - borderMode=cv2.BORDER_REPLICATE) # pylint:disable=no-member + batch = np.array([cv2.warpAffine(image, + mat, + 
(self._training_size, self._training_size), + borderMode=cv2.BORDER_REPLICATE) + for image, mat in zip(batch, mats)]) logger.trace("Randomly transformed image") - return result + return batch - def do_random_flip(self, image): - """ Perform flip on image if random number is within threshold """ - logger.trace("Randomly flipping image") - random_flip = self.config.get("random_flip", 50) / 100 - if np.random.random() < random_flip: - logger.trace("Flip within threshold. Flipping") - retval = image[:, ::-1] - else: - logger.trace("Flip outside threshold. Not Flipping") - retval = image - logger.trace("Randomly flipped image") - return retval + def random_flip(self, batch): + """ Perform random horizontal flipping on the passed in batch. - def random_warp(self, image): - """ get pair of random warped images from aligned face image """ - logger.trace("Randomly warping image") - height, width = image.shape[0:2] - coverage = self.get_coverage(image) // 2 - try: - assert height == width and height % 2 == 0 - except AssertionError as err: - msg = ("Training images should be square with an even number of pixels across each " - "side. An image was found with width: {}, height: {}." - "\nMost likely this is a frame rather than a face within your training set. " - "\nMake sure that the only images within your training set are faces generated " - "from the Extract process.".format(width, height)) - raise FaceswapError(msg) from err + The probability of flipping an image is set in :file:`config.train.ini` - range_ = np.linspace(height // 2 - coverage, height // 2 + coverage, 5, dtype='float32') - mapx = np.broadcast_to(range_, (5, 5)).copy() - mapy = mapx.T - # mapx, mapy = np.float32(np.meshgrid(range_,range_)) # instead of broadcast + Parameters + ---------- + batch: numpy.ndarray + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `channels`) and in `BGR` format. 
- pad = int(1.25 * self.input_size) - slices = slice(pad // 10, -pad // 10) - dst_slices = [slice(0, (size + 1), (size // 4)) for size in self.output_sizes] - interp = np.empty((2, self.input_size, self.input_size), dtype='float32') + Returns + ---------- + numpy.ndarray + A 4-dimensional array of the same shape as :attr:`batch` with transformation applied. + """ + if not self.is_display: + logger.trace("Randomly flipping image") + randoms = np.random.rand(self._batchsize) + indices = np.where(randoms > self._config.get("random_flip", 50) / 100)[0] + batch[indices] = batch[indices, :, ::-1] + logger.trace("Randomly flipped %s images of %s", len(indices), self._batchsize) + return batch - for i, map_ in enumerate([mapx, mapy]): - map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale) - interp[i] = cv2.resize(map_, (pad, pad))[slices, slices] # pylint:disable=no-member + def warp(self, batch, to_landmarks=False, **kwargs): + """ Perform random warping on the passed in batch by one of two methods. - warped_image = cv2.remap( # pylint:disable=no-member - image, interp[0], interp[1], cv2.INTER_LINEAR) # pylint:disable=no-member - logger.trace("Warped image shape: %s", warped_image.shape) + Parameters + ---------- + batch: numpy.ndarray + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format. + to_landmarks: bool, optional + If ``False`` perform standard random warping of the input image. If ``True`` perform + warping to semi-random similar corresponding landmarks from the other side. 
Default: + ``False`` + kwargs: dict + If :attr:`to_landmarks` is ``True`` the following additional kwargs must be passed in: - src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1) - dst_points = [np.mgrid[dst_slice, dst_slice] for dst_slice in dst_slices] - mats = [umeyama(src_points, True, dst_pts.T.reshape(-1, 2))[0:2] - for dst_pts in dst_points] + * **batch_src_points** (`numpy.ndarray`) - A batch of 68 point landmarks for the \ + source faces. This is a 3-dimensional array in the shape (`batchsize`, `68`, `2`). - target_images = [cv2.warpAffine(image, # pylint:disable=no-member - mat, - (self.output_sizes[idx], self.output_sizes[idx])) - for idx, mat in enumerate(mats)] + * **batch_dst_points** (`numpy.ndarray`) - A batch of randomly chosen closest match \ + destination faces landmarks. This is a 3-dimensional array in the shape (`batchsize`, \ + `68`, `2`). + Returns + ---------- + numpy.ndarray + A 4-dimensional array of the same shape as :attr:`batch` with warping applied. 
+ """ + if to_landmarks: + return self._random_warp_landmarks(batch, **kwargs).astype("float32") / 255.0 + return self._random_warp(batch).astype("float32") / 255.0 - logger.trace("Target image shapes: %s", [tgt.shape for tgt in target_images]) - return self.compile_images(warped_image, target_images) + def _random_warp(self, batch): + """ Randomly warp the input batch """ + logger.trace("Randomly warping batch") + mapx = self._constants["warp_mapx"] + mapy = self._constants["warp_mapy"] + pad = self._constants["warp_pad"] + slices = self._constants["warp_slices"] - def random_warp_landmarks(self, image, src_points=None, dst_points=None): - """ get warped image, target image and target mask - From DFAKER plugin """ + rands = np.random.normal(size=(self._batchsize, 2, 5, 5), + scale=self._scale).astype("float32") + batch_maps = np.stack((mapx, mapy), axis=1) + rands + batch_interp = np.array([[cv2.resize(map_, (pad, pad))[slices, slices] for map_ in maps] + for maps in batch_maps]) + warped_batch = np.array([cv2.remap(image, interp[0], interp[1], cv2.INTER_LINEAR) + for image, interp in zip(batch, batch_interp)]) + + logger.trace("Warped image shape: %s", warped_batch.shape) + return warped_batch + + def _random_warp_landmarks(self, batch, batch_src_points, batch_dst_points): + """ From dfaker. 
Warp the image to a similar set of landmarks from the opposite side """ logger.trace("Randomly warping landmarks") - size = image.shape[0] - coverage = self.get_coverage(image) // 2 + edge_anchors = self._constants["warp_lm_edge_anchors"] + grids = self._constants["warp_lm_grids"] + slices = self._constants["tgt_slices"] - p_mx = size - 1 - p_hf = (size // 2) - 1 + batch_dst = (batch_dst_points + np.random.normal(size=batch_dst_points.shape, + scale=2.0)).astype("int32") - edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0), - (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)] - grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)] + face_cores = [cv2.convexHull(np.concatenate([src[17:], dst[17:]], axis=0)) + for src, dst in zip(batch_src_points, batch_dst)] - source = src_points - destination = (dst_points.copy().astype('float32') + - np.random.normal(size=dst_points.shape, scale=2.0)) - destination = destination.astype('uint8') + batch_src = np.append(batch_src_points, edge_anchors, axis=1) + batch_dst = np.append(batch_dst, edge_anchors, axis=1) - face_core = cv2.convexHull(np.concatenate( # pylint:disable=no-member - [source[17:], destination[17:]], axis=0).astype(int)) + rem_indices = [list(set(idx for fpl in (src, dst) + for idx, (pty, ptx) in enumerate(fpl) + if cv2.pointPolygonTest(face_core, (pty, ptx), False) >= 0)) + for src, dst, face_core in zip(batch_src[:, :18, :], + batch_dst[:, :18, :], + face_cores)] + batch_src = [np.delete(src, idxs, axis=0) for idxs, src in zip(rem_indices, batch_src)] + batch_dst = [np.delete(dst, idxs, axis=0) for idxs, dst in zip(rem_indices, batch_dst)] - source = [(pty, ptx) for ptx, pty in source] + edge_anchors - destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors - - indicies_to_remove = set() - for fpl in source, destination: - for idx, (pty, ptx) in enumerate(fpl): - if idx > 17: - break - elif cv2.pointPolygonTest(face_core, # pylint:disable=no-member - (pty, ptx), - False) >= 
0: - indicies_to_remove.add(idx) - - for idx in sorted(indicies_to_remove, reverse=True): - source.pop(idx) - destination.pop(idx) - - grid_z = griddata(destination, source, (grid_x, grid_y), method="linear") - map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size) - map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size) - map_x_32 = map_x.astype('float32') - map_y_32 = map_y.astype('float32') - - warped_image = cv2.remap(image, # pylint:disable=no-member - map_x_32, - map_y_32, - cv2.INTER_LINEAR, # pylint:disable=no-member - cv2.BORDER_TRANSPARENT) # pylint:disable=no-member - target_image = image - - # TODO Make sure this replacement is correct - slices = slice(size // 2 - coverage, size // 2 + coverage) -# slices = slice(size // 32, size - size // 32) # 8px on a 256px image - warped_image = cv2.resize( # pylint:disable=no-member - warped_image[slices, slices, :], (self.input_size, self.input_size), - cv2.INTER_AREA) # pylint:disable=no-member - logger.trace("Warped image shape: %s", warped_image.shape) - target_images = [cv2.resize(target_image[slices, slices, :], # pylint:disable=no-member - (size, size), - cv2.INTER_AREA) # pylint:disable=no-member - for size in self.output_sizes] - - logger.trace("Target image shapea: %s", [img.shape for img in target_images]) - return self.compile_images(warped_image, target_images) - - def compile_images(self, warped_image, target_images): - """ Compile the warped images, target images and mask for feed """ - warped_image, _ = self.separate_mask(warped_image) - final_target_images = list() - target_mask = None - for target_image in target_images: - image, mask = self.separate_mask(target_image) - final_target_images.append(image) - # Add the mask if it exists and is the same size as our largest output - if mask is not None and mask.shape[1] == max(self.output_sizes): - target_mask = mask - - retval = [warped_image] + final_target_images - if target_mask is not None: - logger.trace("Target 
mask shape: %s", target_mask.shape) - retval.append(target_mask) - - logger.trace("Final shapes: %s", [img.shape for img in retval]) - return retval - - -def stack_images(images): - """ Stack images """ - logger.debug("Stack images") - - def get_transpose_axes(num): - if num % 2 == 0: - logger.debug("Even number of images to stack") - y_axes = list(range(1, num - 1, 2)) - x_axes = list(range(0, num - 1, 2)) - else: - logger.debug("Odd number of images to stack") - y_axes = list(range(0, num - 1, 2)) - x_axes = list(range(1, num - 1, 2)) - return y_axes, x_axes, [num - 1] - - images_shape = np.array(images.shape) - new_axes = get_transpose_axes(len(images_shape)) - new_shape = [np.prod(images_shape[x]) for x in new_axes] - logger.debug("Stacked images") - return np.transpose( - images, - axes=np.concatenate(new_axes) - ).reshape(new_shape) + grid_z = np.array([griddata(dst, src, (grids[0], grids[1]), method="linear") + for src, dst in zip(batch_src, batch_dst)]) + maps = grid_z.reshape(self._batchsize, + self._training_size, + self._training_size, + 2).astype("float32") + warped_batch = np.array([cv2.remap(image, + map_[..., 1], + map_[..., 0], + cv2.INTER_LINEAR, + cv2.BORDER_TRANSPARENT) + for image, map_ in zip(batch, maps)]) + warped_batch = np.array([cv2.resize(image[slices, slices, :], + (self._input_size, self._input_size), + cv2.INTER_AREA) + for image in warped_batch]) + logger.trace("Warped batch shape: %s", warped_batch.shape) + return warped_batch diff --git a/lib/utils.py b/lib/utils.py index 88a527c..ee8c6ff 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -4,27 +4,18 @@ import json import logging import os -import subprocess import sys import urllib import warnings import zipfile -from hashlib import sha1 + from pathlib import Path from re import finditer from multiprocessing import current_process from socket import timeout as socket_timeout, error as socket_error -import imageio_ffmpeg as im_ffm from tqdm import tqdm -import numpy as np -import cv2 - 
- -from lib.faces_detect import DetectedFace - - # Global variables _image_extensions = [ # pylint:disable=invalid-name ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"] @@ -132,6 +123,22 @@ def get_image_paths(directory): return dir_contents +def convert_to_secs(*args): + """ converts a time to second. Either convert_to_secs(min, secs) or + convert_to_secs(hours, mins, secs). """ + logger = logging.getLogger(__name__) # pylint:disable=invalid-name + logger.debug("from time: %s", args) + retval = 0.0 + if len(args) == 1: + retval = float(args[0]) + elif len(args) == 2: + retval = 60 * float(args[0]) + float(args[1]) + elif len(args) == 3: + retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2]) + logger.debug("to secs: %s", retval) + return retval + + def full_path_split(path): """ Split a given path into all of it's separate components """ logger = logging.getLogger(__name__) # pylint:disable=invalid-name @@ -151,147 +158,6 @@ def full_path_split(path): return allparts -def cv2_read_img(filename, raise_error=False): - """ Read an image with cv2 and check that an image was actually loaded. - Logs an error if the image returned is None. or an error has occured. - - Pass raise_error=True if error should be raised """ - logger = logging.getLogger(__name__) # pylint:disable=invalid-name - logger.trace("Requested image: '%s'", filename) - success = True - image = None - try: - image = cv2.imread(filename) # pylint:disable=no-member,c-extension-no-member - if image is None: - raise ValueError - except TypeError: - success = False - msg = "Error while reading image (TypeError): '{}'".format(filename) - logger.error(msg) - if raise_error: - raise Exception(msg) - except ValueError: - success = False - msg = ("Error while reading image. 
This is most likely caused by special characters in " - "the filename: '{}'".format(filename)) - logger.error(msg) - if raise_error: - raise Exception(msg) - except Exception as err: # pylint:disable=broad-except - success = False - msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err)) - logger.error(msg) - if raise_error: - raise Exception(msg) - logger.trace("Loaded image: '%s'. Success: %s", filename, success) - return image - - -def hash_image_file(filename): - """ Return an image file's sha1 hash """ - logger = logging.getLogger(__name__) # pylint:disable=invalid-name - img = cv2_read_img(filename, raise_error=True) - img_hash = sha1(img).hexdigest() - logger.trace("filename: '%s', hash: %s", filename, img_hash) - return img_hash - - -def hash_encode_image(image, extension): - """ Encode the image, get the hash and return the hash with - encoded image """ - img = cv2.imencode(extension, image)[1] # pylint:disable=no-member,c-extension-no-member - f_hash = sha1( - cv2.imdecode( # pylint:disable=no-member,c-extension-no-member - img, - cv2.IMREAD_UNCHANGED)).hexdigest() # pylint:disable=no-member,c-extension-no-member - return f_hash, img - - -def convert_to_secs(*args): - """ converts a time to second. Either convert_to_secs(min, secs) or - convert_to_secs(hours, mins, secs). """ - logger = logging.getLogger(__name__) # pylint:disable=invalid-name - logger.debug("from time: %s", args) - retval = 0.0 - if len(args) == 1: - retval = float(args[0]) - elif len(args) == 2: - retval = 60 * float(args[0]) + float(args[1]) - elif len(args) == 3: - retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2]) - logger.debug("to secs: %s", retval) - return retval - - -def count_frames_and_secs(path, timeout=60): - """ - Adapted From ffmpeg_imageio, to handle occasional hanging issue: - https://github.com/imageio/imageio-ffmpeg - - Get the number of frames and number of seconds for the given video - file. 
Note that this operation can be quite slow for large files. - - Disclaimer: I've seen this produce different results from actually reading - the frames with older versions of ffmpeg (2.x). Therefore I cannot say - with 100% certainty that the returned values are always exact. - """ - # https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg - - logger = logging.getLogger(__name__) # pylint:disable=invalid-name - assert isinstance(path, str), "Video path must be a string" - exe = im_ffm.get_ffmpeg_exe() - iswin = sys.platform.startswith("win") - logger.debug("iswin: '%s'", iswin) - cmd = [exe, "-i", path, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"] - logger.debug("FFMPEG Command: '%s'", " ".join(cmd)) - attempts = 3 - for attempt in range(attempts): - try: - logger.debug("attempt: %s of %s", attempt + 1, attempts) - out = subprocess.check_output(cmd, - stderr=subprocess.STDOUT, - shell=iswin, - timeout=timeout) - logger.debug("Succesfully communicated with FFMPEG") - break - except subprocess.CalledProcessError as err: - out = err.output.decode(errors="ignore") - raise RuntimeError("FFMEG call failed with {}:\n{}".format(err.returncode, out)) - except subprocess.TimeoutExpired as err: - this_attempt = attempt + 1 - if this_attempt == attempts: - msg = ("FFMPEG hung while attempting to obtain the frame count. " - "Sometimes this issue resolves itself, so you can try running again. " - "Otherwise use the Effmpeg Tool to extract the frames from your video into " - "a folder, and then run the requested Faceswap process on that folder.") - raise FaceswapError(msg) from err - logger.warning("FFMPEG hung while attempting to obtain the frame count. " - "Retrying %s of %s", this_attempt + 1, attempts) - continue - - # Note that other than with the subprocess calls below, ffmpeg wont hang here. - # Worst case Python will stop/crash and ffmpeg will continue running until done. 
- - nframes = nsecs = None - for line in reversed(out.splitlines()): - if not line.startswith(b"frame="): - continue - line = line.decode(errors="ignore") - logger.debug("frame line: '%s'", line) - idx = line.find("frame=") - if idx >= 0: - splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip() - nframes = int(splitframes) - idx = line.find("time=") - if idx >= 0: - splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip() - nsecs = convert_to_secs(*splittime.split(":")) - logger.debug("nframes: %s, nsecs: %s", nframes, nsecs) - return nframes, nsecs - - raise RuntimeError("Could not get number of frames") # pragma: no cover - - def backup_file(directory, filename): """ Backup a given file by appending .bk to the end """ logger = logging.getLogger(__name__) # pylint:disable=invalid-name @@ -348,80 +214,6 @@ def deprecation_warning(func_name, additional_info=None): logger.warning(msg) -def rotate_landmarks(face, rotation_matrix): - # pylint:disable=c-extension-no-member - """ Rotate the landmarks and bounding box for faces - found in rotated images. 
- Pass in a DetectedFace object or Alignments dict """ - logger = logging.getLogger(__name__) # pylint:disable=invalid-name - logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s", - rotation_matrix, type(face)) - rotated_landmarks = None - # Detected Face Object - if isinstance(face, DetectedFace): - bounding_box = [[face.x, face.y], - [face.x + face.w, face.y], - [face.x + face.w, face.y + face.h], - [face.x, face.y + face.h]] - landmarks = face.landmarks_xy - - # Alignments Dict - elif isinstance(face, dict) and "x" in face: - bounding_box = [[face.get("x", 0), face.get("y", 0)], - [face.get("x", 0) + face.get("w", 0), - face.get("y", 0)], - [face.get("x", 0) + face.get("w", 0), - face.get("y", 0) + face.get("h", 0)], - [face.get("x", 0), - face.get("y", 0) + face.get("h", 0)]] - landmarks = face.get("landmarks_xy", list()) - - else: - raise ValueError("Unsupported face type") - - logger.trace("Original landmarks: %s", landmarks) - - rotation_matrix = cv2.invertAffineTransform( # pylint:disable=no-member - rotation_matrix) - rotated = list() - for item in (bounding_box, landmarks): - if not item: - continue - points = np.array(item, np.int32) - points = np.expand_dims(points, axis=0) - transformed = cv2.transform(points, # pylint:disable=no-member - rotation_matrix).astype(np.int32) - rotated.append(transformed.squeeze()) - - # Bounding box should follow x, y planes, so get min/max - # for non-90 degree rotations - pt_x = min([pnt[0] for pnt in rotated[0]]) - pt_y = min([pnt[1] for pnt in rotated[0]]) - pt_x1 = max([pnt[0] for pnt in rotated[0]]) - pt_y1 = max([pnt[1] for pnt in rotated[0]]) - width = pt_x1 - pt_x - height = pt_y1 - pt_y - - if isinstance(face, DetectedFace): - face.x = int(pt_x) - face.y = int(pt_y) - face.w = int(width) - face.h = int(height) - face.r = 0 - if len(rotated) > 1: - rotated_landmarks = [tuple(point) for point in rotated[1].tolist()] - face.landmarks_xy = rotated_landmarks - else: - face["left"] = int(pt_x) - 
face["top"] = int(pt_y) - face["right"] = int(pt_x1) - face["bottom"] = int(pt_y1) - rotated_landmarks = face - - logger.trace("Rotated landmarks: %s", rotated_landmarks) - return face - - def camel_case_split(identifier): """ Split a camel case name from: https://stackoverflow.com/questions/29916065 """ diff --git a/plugins/extract/detect/_base.py b/plugins/extract/detect/_base.py index 5791159..6a29de1 100755 --- a/plugins/extract/detect/_base.py +++ b/plugins/extract/detect/_base.py @@ -18,8 +18,7 @@ To get a :class:`~lib.faces_detect.DetectedFace` object use the function: import cv2 import numpy as np -from lib.faces_detect import DetectedFace -from lib.utils import rotate_landmarks +from lib.faces_detect import DetectedFace, rotate_landmarks from plugins.extract._base import Extractor, logger diff --git a/plugins/train/model/_base.py b/plugins/train/model/_base.py index caa16fd..69d51dc 100644 --- a/plugins/train/model/_base.py +++ b/plugins/train/model/_base.py @@ -9,6 +9,7 @@ import os import sys import time +from concurrent import futures from json import JSONDecodeError import keras @@ -24,7 +25,6 @@ from lib.model.losses import (DSSIMObjective, PenalizedLoss, gradient_loss, mask generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur) from lib.model.nn_blocks import NNBlocks from lib.model.optimizers import Adam -from lib.multithreading import MultiThread from lib.utils import deprecation_warning, FaceswapError from plugins.train._config import Config @@ -466,21 +466,13 @@ class ModelBase(): backup_func = self.backup.backup_model if self.should_backup(save_averages) else None if backup_func: logger.info("Backing up models...") - save_threads = list() - for network in self.networks.values(): - name = "save_{}".format(network.name) - save_threads.append(MultiThread(network.save, - name=name, - backup_func=backup_func)) - save_threads.append(MultiThread(self.state.save, - name="save_state", - backup_func=backup_func)) - for thread in save_threads: - 
thread.start() - for thread in save_threads: - if thread.has_error: - logger.error(thread.errors[0]) - thread.join() + executor = futures.ThreadPoolExecutor() + save_threads = [executor.submit(network.save, backup_func=backup_func) + for network in self.networks.values()] + save_threads.append(executor.submit(self.state.save, backup_func=backup_func)) + futures.wait(save_threads) + # call result() to capture errors + _ = [thread.result() for thread in save_threads] msg = "[Saved models]" if save_averages: lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0], diff --git a/plugins/train/trainer/_base.py b/plugins/train/trainer/_base.py index 65825a0..e197435 100644 --- a/plugins/train/trainer/_base.py +++ b/plugins/train/trainer/_base.py @@ -33,7 +33,7 @@ from tensorflow.python import errors_impl as tf_errors # pylint:disable=no-name from lib.alignments import Alignments from lib.faces_detect import DetectedFace -from lib.training_data import TrainingDataGenerator, stack_images +from lib.training_data import TrainingDataGenerator from lib.utils import FaceswapError, get_folder, get_image_paths from plugins.train._config import Config @@ -292,10 +292,10 @@ class Batcher(): """ Return the next batch from the generator Items should come out as: (warped, target [, mask]) """ batch = next(self.feed) - feed = batch[1] - batch = batch[2:] # Remove full size samples and feed from batch - mask = batch[-1] - batch = [[feed, mask], batch] if self.use_mask else [feed, batch] + if self.use_mask: + batch = [[batch["feed"], batch["masks"]], batch["targets"] + [batch["masks"]]] + else: + batch = [batch["feed"], batch["targets"]] self.generate_preview(do_preview) return batch @@ -309,13 +309,10 @@ class Batcher(): if self.preview_feed is None: self.set_preview_feed() batch = next(self.preview_feed) - self.samples, feed = batch[:2] - batch = batch[2:] # Remove full size samples and feed from batch - self.target = batch[self.model.largest_face_index] + self.samples = 
batch["samples"] + self.target = [batch["targets"][self.model.largest_face_index]] if self.use_mask: - mask = batch[-1] - batch = [[feed, mask], batch] - self.target = [self.target, mask] + self.target += [batch["masks"]] def set_preview_feed(self): """ Set the preview dictionary """ @@ -347,15 +344,11 @@ class Batcher(): def compile_timelapse_sample(self): """ Timelapse samples """ batch = next(self.timelapse_feed) - samples, feed = batch[:2] - batchsize = len(samples) - batch = batch[2:] # Remove full size samples and feed from batch - images = batch[self.model.largest_face_index] + batchsize = len(batch["samples"]) + images = [batch["targets"][self.model.largest_face_index]] if self.use_mask: - mask = batch[-1] - batch = [[feed, mask], batch] - images = [images, mask] - sample = self.compile_sample(batchsize, samples=samples, images=images) + images = images + [batch["masks"]] + sample = self.compile_sample(batchsize, samples=batch["samples"], images=images) return sample def set_timelapse_feed(self, images, batchsize): @@ -405,10 +398,10 @@ class Samples(): for side, samples in self.images.items(): other_side = "a" if side == "b" else "b" - predictions = [preds["{}_{}".format(side, side)], + predictions = [preds["{0}_{0}".format(side)], preds["{}_{}".format(other_side, side)]] display = self.to_full_frame(side, samples, predictions) - headers[side] = self.get_headers(side, other_side, display[0].shape[1]) + headers[side] = self.get_headers(side, display[0].shape[1]) figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) if self.images[side][0].shape[0] % 2 == 1: figures[side] = np.concatenate([figures[side], @@ -547,22 +540,22 @@ class Samples(): logger.debug("Overlayed foreground. 
Shape: %s", retval.shape) return retval - def get_headers(self, side, other_side, width): + def get_headers(self, side, width): """ Set headers for images """ - logger.debug("side: '%s', other_side: '%s', width: %s", - side, other_side, width) + logger.debug("side: '%s', width: %s", + side, width) + titles = ("Original", "Swap") if side == "a" else ("Swap", "Original") side = side.upper() - other_side = other_side.upper() height = int(64 * self.scaling) total_width = width * 3 logger.debug("height: %s, total_width: %s", height, total_width) font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member - texts = ["Target {}".format(side), - "{} > {}".format(side, side), - "{} > {}".format(side, other_side)] + texts = ["{} ({})".format(titles[0], side), + "{0} > {0}".format(titles[0]), + "{} > {}".format(titles[0], titles[1])] text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member font, - self.scaling, + self.scaling * 0.8, 1)[0] for idx in range(len(texts))] text_y = int((height + text_sizes[0][1]) / 2) @@ -576,7 +569,7 @@ class Samples(): text, (text_x[idx], text_y), font, - self.scaling, + self.scaling * 0.8, (0, 0, 0), 1, lineType=cv2.LINE_AA) # pylint: disable=no-member @@ -703,3 +696,25 @@ class Landmarks(): detected_face.load_aligned(None, size=self.size) landmarks[detected_face.hash] = detected_face.aligned_landmarks return landmarks + + +def stack_images(images): + """ Stack images """ + logger.debug("Stack images") + + def get_transpose_axes(num): + if num % 2 == 0: + logger.debug("Even number of images to stack") + y_axes = list(range(1, num - 1, 2)) + x_axes = list(range(0, num - 1, 2)) + else: + logger.debug("Odd number of images to stack") + y_axes = list(range(0, num - 1, 2)) + x_axes = list(range(1, num - 1, 2)) + return y_axes, x_axes, [num - 1] + + images_shape = np.array(images.shape) + new_axes = get_transpose_axes(len(images_shape)) + new_shape = [np.prod(images_shape[x]) for x in new_axes] + logger.debug("Stacked images") + return 
np.transpose(images, axes=np.concatenate(new_axes)).reshape(new_shape) diff --git a/scripts/convert.py b/scripts/convert.py index c184b1f..a8f4e35 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -17,9 +17,10 @@ from lib import Serializer from lib.convert import Converter from lib.faces_detect import DetectedFace from lib.gpu_stats import GPUStats +from lib.image import read_image_hash from lib.multithreading import MultiThread, total_cpus from lib.queue_manager import queue_manager -from lib.utils import FaceswapError, get_folder, get_image_paths, hash_image_file +from lib.utils import FaceswapError, get_folder, get_image_paths from plugins.extract.pipeline import Extractor from plugins.plugin_loader import PluginLoader @@ -682,7 +683,7 @@ class OptionalActions(): file_list = [path for path in get_image_paths(input_aligned_dir)] logger.info("Getting Face Hashes for selected Aligned Images") for face in tqdm(file_list, desc="Hashing Faces"): - face_hashes.append(hash_image_file(face)) + face_hashes.append(read_image_hash(face)) logger.debug("Face Hashes: %s", (len(face_hashes))) if not face_hashes: raise FaceswapError("Aligned directory is empty, no faces will be converted!") @@ -746,5 +747,5 @@ class Legacy(): continue hash_faces = all_faces[frame] for index, face_path in hash_faces.items(): - hash_faces[index] = hash_image_file(face_path) + hash_faces[index] = read_image_hash(face_path) self.alignments.add_face_hashes(frame, hash_faces) diff --git a/scripts/extract.py b/scripts/extract.py index 948aae6..5306c3e 100644 --- a/scripts/extract.py +++ b/scripts/extract.py @@ -8,9 +8,10 @@ from pathlib import Path from tqdm import tqdm +from lib.image import encode_image_with_hash from lib.multithreading import MultiThread from lib.queue_manager import queue_manager -from lib.utils import get_folder, hash_encode_image, deprecation_warning +from lib.utils import get_folder, deprecation_warning from plugins.extract.pipeline import Extractor from 
scripts.fsmedia import Alignments, Images, PostProcess, Utils @@ -255,7 +256,7 @@ class Extract(): face = detected_face["face"] resized_face = face.aligned_face - face.hash, img = hash_encode_image(resized_face, extension) + face.hash, img = encode_image_with_hash(resized_face, extension) self.save_queue.put((out_filename, img)) final_faces.append(face.to_alignment()) self.alignments.data[os.path.basename(filename)] = final_faces diff --git a/scripts/fsmedia.py b/scripts/fsmedia.py index 5a763cd..b9cd8c8 100644 --- a/scripts/fsmedia.py +++ b/scripts/fsmedia.py @@ -16,8 +16,9 @@ import numpy as np from lib.aligner import Extract as AlignerExtract from lib.alignments import Alignments as AlignmentsBase from lib.face_filter import FaceFilter as FilterFunc -from lib.utils import (camel_case_split, count_frames_and_secs, cv2_read_img, get_folder, - get_image_paths, set_system_verbosity, _video_extensions) +from lib.image import count_frames_and_secs, read_image +from lib.utils import (camel_case_split, get_folder, get_image_paths, set_system_verbosity, + _video_extensions) logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -183,7 +184,7 @@ class Images(): """ Load frames from disk """ logger.debug("Input is separate Frames. 
Loading images") for filename in self.input_images: - image = cv2_read_img(filename, raise_error=False) + image = read_image(filename, raise_error=False) if image is None: continue yield filename, image @@ -212,7 +213,7 @@ class Images(): logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename) retval = self.load_one_video_frame(int(frame_no)) else: - retval = cv2_read_img(filename, raise_error=True) + retval = read_image(filename, raise_error=True) return retval def load_one_video_frame(self, frame_no): diff --git a/scripts/train.py b/scripts/train.py index f834d0d..3ec713b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -12,10 +12,11 @@ import cv2 import tensorflow as tf from keras.backend.tensorflow_backend import set_session +from lib.image import read_image from lib.keypress import KBHit from lib.multithreading import MultiThread -from lib.queue_manager import queue_manager -from lib.utils import cv2_read_img, get_folder, get_image_paths, set_system_verbosity +from lib.queue_manager import queue_manager # noqa pylint:disable=unused-import +from lib.utils import get_folder, get_image_paths, set_system_verbosity from plugins.plugin_loader import PluginLoader logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -176,7 +177,7 @@ class Train(): @property def image_size(self): """ Get the training set image size for storing in model data """ - image = cv2_read_img(self.images["a"][0], raise_error=True) + image = read_image(self.images["a"][0], raise_error=True) size = image.shape[0] logger.debug("Training image size: %s", size) return size diff --git a/tools/lib_alignments/media.py b/tools/lib_alignments/media.py index 6dd867c..aae4661 100644 --- a/tools/lib_alignments/media.py +++ b/tools/lib_alignments/media.py @@ -14,8 +14,8 @@ from tqdm import tqdm from lib.aligner import Extract as AlignerExtract from lib.alignments import Alignments from lib.faces_detect import DetectedFace -from lib.utils import 
(_image_extensions, _video_extensions, count_frames_and_secs, cv2_read_img, - hash_image_file, hash_encode_image) +from lib.image import count_frames_and_secs, encode_image_with_hash, read_image, read_image_hash +from lib.utils import _image_extensions, _video_extensions logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -175,7 +175,7 @@ class MediaLoader(): else: src = os.path.join(self.folder, filename) logger.trace("Loading image: '%s'", src) - image = cv2_read_img(src, raise_error=True) + image = read_image(src, raise_error=True) return image def load_video_frame(self, filename): @@ -210,7 +210,7 @@ class Faces(MediaLoader): continue filename = os.path.splitext(face)[0] file_extension = os.path.splitext(face)[1] - face_hash = hash_image_file(os.path.join(self.folder, face)) + face_hash = read_image_hash(os.path.join(self.folder, face)) retval = {"face_fullname": face, "face_name": filename, "face_extension": file_extension, @@ -358,7 +358,7 @@ class ExtractedFaces(): @staticmethod def save_face_with_hash(filename, extension, face): """ Save a face and return it's hash """ - f_hash, img = hash_encode_image(face, extension) + f_hash, img = encode_image_with_hash(face, extension) logger.trace("Saving face: '%s'", filename) with open(filename, "wb") as out_file: out_file.write(img) diff --git a/tools/sort.py b/tools/sort.py index b7c3f4e..9baf04f 100644 --- a/tools/sort.py +++ b/tools/sort.py @@ -16,8 +16,8 @@ from tqdm import tqdm from lib.cli import FullHelpArgumentParser from lib import Serializer from lib.faces_detect import DetectedFace +from lib.image import read_image from lib.queue_manager import queue_manager -from lib.utils import cv2_read_img from lib.vgg_face2_keras import VGGFace2 as VGGFace from plugins.plugin_loader import PluginLoader @@ -106,7 +106,7 @@ class Sort(): @staticmethod def get_landmarks(filename): """ Extract the face from a frame (If not alignments file found) """ - image = cv2_read_img(filename, raise_error=True) + 
image = read_image(filename, raise_error=True) feed = Sort.alignment_dict(image) feed["filename"] = filename queue_manager.get_queue("in").put(feed) @@ -161,7 +161,7 @@ class Sort(): logger.info("Sorting by face similarity...") images = np.array(self.find_images(input_dir)) - preds = np.array([self.vgg_face.predict(cv2_read_img(img, raise_error=True)) + preds = np.array([self.vgg_face.predict(read_image(img, raise_error=True)) for img in tqdm(images, desc="loading", file=sys.stdout)]) logger.info("Sorting. Depending on ths size of your dataset, this may take a few " "minutes...") @@ -264,7 +264,7 @@ class Sort(): logger.info("Sorting by histogram similarity...") img_list = [ - [img, cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])] + [img, cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256])] for img in tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout) ] @@ -294,7 +294,7 @@ class Sort(): img_list = [ [img, - cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256]), 0] + cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256]), 0] for img in tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout) ] @@ -548,7 +548,7 @@ class Sort(): input_dir = self.args.input_dir logger.info("Preparing to group...") if group_method == 'group_blur': - temp_list = [[img, self.estimate_blur(cv2_read_img(img, raise_error=True))] + temp_list = [[img, self.estimate_blur(read_image(img, raise_error=True))] for img in tqdm(self.find_images(input_dir), desc="Reloading", @@ -576,7 +576,7 @@ class Sort(): elif group_method == 'group_hist': temp_list = [ [img, - cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])] + cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256])] for img in tqdm(self.find_images(input_dir), desc="Reloading", @@ -632,7 +632,7 @@ class Sort(): Estimate the amount of blur an image has with the 
variance of the Laplacian. Normalize by pixel number to offset the effect of image size on pixel gradients & variance """ - image = cv2_read_img(image_file, raise_error=True) + image = read_image(image_file, raise_error=True) if image.ndim == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) blur_map = cv2.Laplacian(image, cv2.CV_32F)