Mirror of https://github.com/zebrajr/faceswap.git (synced 2025-12-06 12:20:27 +01:00)
* Core Updates
- Remove lib.utils.keras_backend_quiet and replace with get_backend() where relevant
- Document lib.gpu_stats and lib.sys_info
- Remove call to GPUStats.is_plaidml from convert and replace with get_backend()
- lib.gui.menu - typofix
* Update Dependencies
Bump Tensorflow Version Check
* Port extraction to tf2
* Add custom import finder for loading Keras or tf.keras depending on backend (an import-hook sketch follows this list)
* Add `tensorflow` to KerasFinder search path
* Basic TF2 training running
* model.initializers - docstring fix
* Fix and pass tests for tf2
* Replace Keras backend tests with faceswap backend tests
* Initial optimizers update
* Monkey patch tf.keras optimizer
* Remove custom Adam Optimizers and Memory Saving Gradients
* Remove multi-gpu option. Add Distribution to cli
* plugins.train.model._base: Add Mirror, Central and Default distribution strategies (see the strategy sketch after this list)
* Update tensorboard kwargs for tf2
* Penalized Loss - Fix for TF2 and AMD
* Fix syntax for tf2.1
* requirements typo fix
* Explicit None for clipnorm if using a distribution strategy
* Fix penalized loss for distribution strategies
* Update Dlight
* typo fix
* Pin to TF2.2
* setup.py - Install tensorflow from pip if not available in Conda
* Add reduction options and set default for mirrored distribution strategy
* Explicitly use default strategy rather than nullcontext
* lib.model.backup_restore documentation
* Remove mirrored strategy reduction method and default based on OS
* Initial restructure - training
* Remove PingPong
Start model.base refactor
* Model saving and resuming enabled
* More tidying up of model.base
* Enable backup and snapshotting
* Re-enable state file
Remove loss names from state file
Fix print loss function
Set snapshot iterations correctly
* Revert original model to Keras Model structure rather than custom layer
Output full model and sub model summary
Change NNBlocks to callables rather than custom keras layers
* Apply custom Conv2D layer
* Finalize NNBlock restructure
Update Dfaker blocks
* Fix reloading model under a different distribution strategy
* Pass command line arguments through to trainer
* Remove training_opts from model and reference params directly
* Tidy up model __init__
* Re-enable tensorboard logging
Suppress "Model Not Compiled" warning
* Fix timelapse
* lib.model.nnblocks - Bugfix residual block
Port dfaker
bugfix original
* dfl-h128 ported
* DFL SAE ported
* IAE Ported
* dlight ported
* port lightweight
* realface ported
* unbalanced ported
* villain ported
* lib.cli.args - Update Batchsize + move allow_growth to config
* Remove output shape definition
Get image sizes per side rather than globally
* Strip mask input from encoder
* Fix learn mask and output learned mask to preview
* Trigger Allow Growth prior to setting strategy
* Fix GUI Graphing
* GUI - Display batchsize correctly + fix training graphs
* Fix penalized loss
* Enable mixed precision training (see the mixed-precision sketch after this list)
* Update analysis displayed batch to match input
* Penalized Loss - Multi-GPU Fix
* Fix all losses for TF2
* Fix Reflect Padding
* Allow different input size for each side of the model
* Fix conv-aware initialization on reload
* Switch allow_growth order
* Move mixed_precision to cli
* Remove distribution strategies
* Compile penalized loss sub-function into LossContainer
* Bump default save interval to 250
Generate preview on first iteration but don't save
Fix iterations to start at 1 instead of 0
Remove training deprecation warnings
Bump some scripts.train loglevels
* Add ability to refresh preview on demand on pop-up window
* Enable refresh of training preview from GUI
* Fix Convert
Debug logging in Initializers
* Fix Preview Tool
* Update Legacy TF1 weights to TF2
Catch stats error on loading stats with missing logs
* lib.gui.popup_configure - Make more responsive + document
* Multiple Outputs supported in trainer
Original Model - Mask output bugfix
* Make universal inference model for convert
Remove scaling from penalized mask loss (now handled at input to y_true)
* Fix inference model to work properly with all models
* Fix multi-scale output for convert
* Fix clipnorm issue with distribution strategies
Edit error message on OOM
* Update plaidml losses
* Add missing file
* Disable gmsd loss for plaidml
* PlaidML - Basic training working
* clipnorm rewriting for mixed-precision
* Inference model creation bugfixes
* Remove debug code
* Bugfix: Default clipnorm to 1.0
* Remove all mask inputs from training code
* Remove mask inputs from convert
* GUI - Analysis Tab - Docstrings
* Fix rate in totals row
* lib.gui - Only update display pages if they have focus
* Save the model on first iteration
* plaidml - Fix SSIM loss with penalized loss
* tools.alignments - Remove manual and fix jobs
* GUI - Remove case formatting on help text
* gui MultiSelect custom widget - Set default values on init
* vgg_face2 - Move to plugins.extract.recognition and use plugins._base base class
cli - Add global GPU Exclude Option
tools.sort - Use global GPU Exclude option for backend
lib.model.session - Exclude all GPUs when running in CPU mode
lib.cli.launcher - Set backend to CPU mode when all GPUs excluded
* Cascade excluded devices to GPU Stats
* Explicit GPU selection for Train and Convert
* Reduce Tensorflow Min GPU Multiprocessor Count to 4
* remove compat.v1 code from extract
* Force TF to skip mixed precision compatibility check if GPUs have been filtered
* Add notes to config for non-working AMD losses
* Raise error if forcing extract to CPU mode
* Fix loading of legacy dfl-sae weights + dfl-sae typo fix
* Remove unused requirements
Update sphinx requirements
Fix broken rst file locations
* docs: lib.gui.display
* clipnorm amd condition check
* documentation - gui.display_analysis
* Documentation - gui.popup_configure
* Documentation - lib.logger
* Documentation - lib.model.initializers
* Documentation - lib.model.layers
* Documentation - lib.model.losses
* Documentation - lib.model.nn_blocks
* Documentation - lib.model.normalization
* Documentation - lib.model.session
* Documentation - lib.plaidml_stats
* Documentation: lib.training_data
* Documentation: lib.utils
* Documentation: plugins.train.model._base
* GUI Stats: prevent stats from using GPU
* Documentation - Original Model
* Documentation: plugins.model.trainer._base
* linting
* unit tests: initializers + losses
* unit tests: nn_blocks
* bugfix - Exclude gpu devices in train, not include
* Enable Exclude-Gpus in Extract
* Enable exclude gpus in tools
* Disallow multiple plugin types in a single model folder
* Automatically add exclude_gpus argument in for cpu backends
* Cpu backend fixes
* Relax optimizer test threshold
* Default Train settings - Set mask to Extended
* Update Extractor cli help text
Update to Python 3.8
* Fix FAN to run on CPU
* lib.plaidml_tools - typofix
* Linux installer - check for curl
* linux installer - typo fix
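
For context on the custom import finder noted in the list above: the idea is a `sys.meta_path` hook that transparently resolves `keras` imports to `tensorflow.keras` on the TensorFlow backends, so plugin code can keep one import style regardless of backend. The following is a minimal sketch of that idea only; the names `KerasFinder` and `_AliasLoader` are illustrative here and this is not the faceswap implementation.

import importlib
import sys
from importlib.abc import Loader, MetaPathFinder
from importlib.util import spec_from_loader


class _AliasLoader(Loader):
    """ Load an already-importable module under an aliased name. """
    def __init__(self, real_name):
        self._real_name = real_name

    def create_module(self, spec):
        # Return the real module so the alias and the target share a single object
        return importlib.import_module(self._real_name)

    def exec_module(self, module):
        pass  # The real module is already initialized by the import above


class KerasFinder(MetaPathFinder):
    """ Redirect `keras` imports (and sub-modules) to `tensorflow.keras`. """
    def find_spec(self, fullname, path=None, target=None):
        if fullname == "keras" or fullname.startswith("keras."):
            return spec_from_loader(fullname, _AliasLoader("tensorflow." + fullname))
        return None


sys.meta_path.insert(0, KerasFinder())  # After this, `import keras` yields tf.keras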
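The distribution-strategy items above follow the standard TF2 pattern: the model is built and compiled inside a `tf.distribute` strategy scope, with TF's default strategy used explicitly when mirroring is not requested (matching "Explicitly use default strategy rather than nullcontext"). Roughly, and with a throwaway toy model standing in for the plugin's real model definition:

import tensorflow as tf


def get_strategy(distributed):
    """ Return a mirrored strategy for multi-GPU, otherwise TF's default strategy. """
    if distributed:
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()  # the default (single device) strategy


strategy = get_strategy(distributed=False)
with strategy.scope():
    # Model building and compiling must happen inside the strategy scope
    model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,)),
                                 tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mae")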
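The mixed precision option maps onto the Keras mixed-precision API that shipped as experimental in the TF 2.2 timeframe this changelog pins to. A sketch of enabling it (optimizer choice is illustrative; newer TF versions moved this out of `experimental`):

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

# Compute in float16 while keeping variables in float32
mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))

# Wrapping the optimizer applies dynamic loss scaling so float16 gradients don't underflow
optimizer = mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(), loss_scale="dynamic")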
#!/usr/bin python3
""" Main entry point to the training process of FaceSwap """

import logging
import os
import sys

from threading import Lock
from time import sleep

import cv2

from lib.image import read_image
from lib.keypress import KBHit
from lib.multithreading import MultiThread
from lib.utils import (get_folder, get_image_paths, FaceswapError, _image_extensions)
from plugins.plugin_loader import PluginLoader

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

class Train():  # pylint:disable=too-few-public-methods
    """ The Faceswap Training Process.

    The training process is responsible for training a model on a set of source faces and a set of
    destination faces.

    The training process is self contained and should not be referenced by any other scripts, so it
    contains no public properties.

    Parameters
    ----------
    arguments: argparse.Namespace
        The arguments to be passed to the training process as generated from Faceswap's command
        line arguments
    """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self._args = arguments
        self._timelapse = self._set_timelapse()
        self._images = self._get_images()
        self._gui_preview_trigger = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])),
                                                 "lib", "gui", ".cache", ".preview_trigger")
        self._stop = False
        self._save_now = False
        self._refresh_preview = False
        self._preview_buffer = dict()
        self._lock = Lock()

        self.trainer_name = self._args.trainer
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def _image_size(self):
        """ int: The training image size. Reads the first image in the training folder and returns
        the size. """
        image = read_image(self._images["a"][0], raise_error=True)
        size = image.shape[0]
        logger.debug("Training image size: %s", size)
        return size

    def _set_timelapse(self):
        """ Set time-lapse paths if requested.

        Returns
        -------
        dict
            The time-lapse keyword arguments for passing to the trainer
        """
        if (not self._args.timelapse_input_a and
                not self._args.timelapse_input_b and
                not self._args.timelapse_output):
            return None
        if (not self._args.timelapse_input_a or
                not self._args.timelapse_input_b or
                not self._args.timelapse_output):
            raise FaceswapError("To enable the timelapse, you have to supply all the parameters "
                                "(--timelapse-input-A, --timelapse-input-B and "
                                "--timelapse-output).")

        timelapse_output = str(get_folder(self._args.timelapse_output))

        for folder in (self._args.timelapse_input_a, self._args.timelapse_input_b):
            if folder is not None and not os.path.isdir(folder):
                raise FaceswapError("The Timelapse path '{}' does not exist".format(folder))
            exts = [os.path.splitext(fname)[-1] for fname in os.listdir(folder)]
            if not any(ext in _image_extensions for ext in exts):
                raise FaceswapError("The Timelapse path '{}' does not contain any valid "
                                    "images".format(folder))

        kwargs = {"input_a": self._args.timelapse_input_a,
                  "input_b": self._args.timelapse_input_b,
                  "output": timelapse_output}
        logger.debug("Timelapse enabled: %s", kwargs)
        return kwargs

    def _get_images(self):
        """ Check the image folders exist and contain images, and obtain the image paths.

        Returns
        -------
        dict
            The image paths for each side. The key is the side, the value is the list of paths
            for that side.
        """
        logger.debug("Getting image paths")
        images = dict()
        for side in ("a", "b"):
            image_dir = getattr(self._args, "input_{}".format(side))
            if not os.path.isdir(image_dir):
                logger.error("Error: '%s' does not exist", image_dir)
                sys.exit(1)

            images[side] = get_image_paths(image_dir)
            if not images[side]:
                logger.error("Error: '%s' contains no images", image_dir)
                sys.exit(1)

        logger.info("Model A Directory: %s", self._args.input_a)
        logger.info("Model B Directory: %s", self._args.input_b)
        logger.debug("Got image paths: %s", [(key, str(len(val)) + " images")
                                             for key, val in images.items()])
        return images

    def process(self):
        """ The entry point for triggering the Training Process.

        Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
        """
        logger.debug("Starting Training Process")
        logger.info("Training data directory: %s", self._args.model_dir)
        thread = self._start_thread()
        # from lib.queue_manager import queue_manager; queue_manager.debug_monitor(1)

        err = self._monitor(thread)
        self._end_thread(thread, err)
        logger.debug("Completed Training Process")

    def _start_thread(self):
        """ Put the :func:`_training` into a background thread so we can keep control.

        Returns
        -------
        :class:`lib.multithreading.MultiThread`
            The background thread for running training
        """
        logger.debug("Launching Trainer thread")
        thread = MultiThread(target=self._training)
        thread.start()
        logger.debug("Launched Trainer thread")
        return thread

    def _end_thread(self, thread, err):
        """ Output message and join thread back to main on termination.

        Parameters
        ----------
        thread: :class:`lib.multithreading.MultiThread`
            The background training thread
        err: bool
            Whether an error has been detected in :func:`_monitor`
        """
        logger.debug("Ending Training thread")
        if err:
            msg = "Error caught! Exiting..."
            log = logger.critical
        else:
            msg = ("Exit requested! The trainer will complete its current cycle, "
                   "save the models and quit (This can take a couple of minutes "
                   "depending on your training speed).")
            if not self._args.redirect_gui:
                msg += " If you want to kill it now, press Ctrl + c"
            log = logger.info
        log(msg)
        self._stop = True
        thread.join()
        sys.stdout.flush()
        logger.debug("Ended training thread")

    def _training(self):
        """ The training process to be run inside a thread. """
        try:
            sleep(1)  # Let preview instructions flush out to logger
            logger.debug("Commencing Training")
            logger.info("Loading data, this may take a while...")
            model = self._load_model()
            trainer = self._load_trainer(model)
            self._run_training_cycle(model, trainer)
        except KeyboardInterrupt:
            try:
                logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting")
                model.save()
                trainer.clear_tensorboard()
            except KeyboardInterrupt:
                logger.info("Saving model weights has been cancelled!")
            sys.exit(0)
        except Exception as err:
            raise err

    def _load_model(self):
        """ Load the model requested for training.

        Returns
        -------
        :file:`plugins.train.model` plugin
            The requested model plugin
        """
        logger.debug("Loading Model")
        model_dir = str(get_folder(self._args.model_dir))
        model = PluginLoader.get_model(self.trainer_name)(
            model_dir,
            self._args,
            training_image_size=self._image_size,
            predict=False)
        model.build()
        logger.debug("Loaded Model")
        return model

    def _load_trainer(self, model):
        """ Load the trainer requested for training.

        Parameters
        ----------
        model: :file:`plugins.train.model` plugin
            The requested model plugin

        Returns
        -------
        :file:`plugins.train.trainer` plugin
            The requested model trainer plugin
        """
        logger.debug("Loading Trainer")
        trainer = PluginLoader.get_trainer(model.trainer)
        trainer = trainer(model,
                          self._images,
                          self._args.batch_size,
                          self._args.configfile)
        logger.debug("Loaded Trainer")
        return trainer

    def _run_training_cycle(self, model, trainer):
        """ Perform the training cycle.

        Handles the background training, updating previews/time-lapse on each save interval,
        and saving the model.

        Parameters
        ----------
        model: :file:`plugins.train.model` plugin
            The requested model plugin
        trainer: :file:`plugins.train.trainer` plugin
            The requested model trainer plugin
        """
        logger.debug("Running Training Cycle")
        if self._args.write_image or self._args.redirect_gui or self._args.preview:
            display_func = self._show
        else:
            display_func = None

        for iteration in range(1, self._args.iterations + 1):
            logger.trace("Training iteration: %s", iteration)
            save_iteration = iteration % self._args.save_interval == 0 or iteration == 1

            if save_iteration or self._save_now or self._refresh_preview:
                viewer = display_func
            else:
                viewer = None
            timelapse = self._timelapse if save_iteration else None
            trainer.train_one_step(viewer, timelapse)
            if self._stop:
                logger.debug("Stop received. Terminating")
                break

            if self._refresh_preview and viewer is not None:
                if self._args.redirect_gui:
                    print("\n")
                    logger.info("[Preview Updated]")
                    logger.debug("Removing gui trigger file: %s", self._gui_preview_trigger)
                    os.remove(self._gui_preview_trigger)
                self._refresh_preview = False

            if save_iteration:
                logger.debug("Save Iteration: (iteration: %s)", iteration)
                model.save()
            elif self._save_now:
                logger.debug("Save Requested: (iteration: %s)", iteration)
                model.save()
                self._save_now = False
        logger.debug("Training cycle complete")
        model.save()
        trainer.clear_tensorboard()
        self._stop = True

    def _monitor(self, thread):
        """ Monitor the background :func:`_training` thread for key presses and errors.

        Returns
        -------
        bool
            ``True`` if there has been an error in the background thread otherwise ``False``
        """
        is_preview = self._args.preview
        preview_trigger_set = False
        logger.debug("Launching Monitor")
        logger.info("===================================================")
        logger.info("  Starting")
        if is_preview:
            logger.info("  Using live preview")
        logger.info("  Press '%s' to save and quit",
                    "Stop" if self._args.redirect_gui or self._args.colab else "ENTER")
        if not self._args.redirect_gui and not self._args.colab:
            logger.info("  Press 'S' to save model weights immediately")
        logger.info("===================================================")

        keypress = KBHit(is_gui=self._args.redirect_gui)
        err = False
        while True:
            try:
                if is_preview:
                    with self._lock:
                        for name, image in self._preview_buffer.items():
                            cv2.imshow(name, image)  # pylint: disable=no-member
                    cv_key = cv2.waitKey(1000)  # pylint: disable=no-member
                else:
                    cv_key = None

                if thread.has_error:
                    logger.debug("Thread error detected")
                    err = True
                    break
                if self._stop:
                    logger.debug("Stop received")
                    break

                # Preview Monitor
                if is_preview and (cv_key == ord("\n") or cv_key == ord("\r")):
                    logger.debug("Exit requested")
                    break
                if is_preview and cv_key == ord("s"):
                    print("\n")
                    logger.info("Save requested")
                    self._save_now = True
                if is_preview and cv_key == ord("r"):
                    print("\n")
                    logger.info("Refresh preview requested")
                    self._refresh_preview = True

                # Console Monitor
                if keypress.kbhit():
                    console_key = keypress.getch()
                    if console_key in ("\n", "\r"):
                        logger.debug("Exit requested")
                        break
                    if console_key in ("s", "S"):
                        logger.info("Save requested")
                        self._save_now = True

                # GUI Preview trigger update monitor
                if self._args.redirect_gui:
                    if not preview_trigger_set and os.path.isfile(self._gui_preview_trigger):
                        print("\n")
                        logger.info("Refresh preview requested")
                        self._refresh_preview = True
                        preview_trigger_set = True

                    if preview_trigger_set and not self._refresh_preview:
                        logger.debug("Resetting GUI preview trigger")
                        preview_trigger_set = False

                sleep(1)
            except KeyboardInterrupt:
                logger.debug("Keyboard Interrupt received")
                break
        keypress.set_normal_term()
        logger.debug("Closed Monitor")
        return err

    def _show(self, image, name=""):
        """ Generate the preview and write preview file output.

        Handles the output and display of preview images.

        Parameters
        ----------
        image: :class:`numpy.ndarray`
            The preview image to be displayed and/or written out
        name: str, optional
            The name of the image for saving or display purposes. If an empty string is passed
            then it will automatically be named. Default: ""
        """
        logger.debug("Updating preview: (name: %s)", name)
        try:
            scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
            if self._args.write_image:
                logger.debug("Saving preview to disk")
                img = "training_preview.jpg"
                imgfile = os.path.join(scriptpath, img)
                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
                logger.debug("Saved preview to: '%s'", img)
            if self._args.redirect_gui:
                logger.debug("Generating preview for GUI")
                img = ".gui_training_preview.jpg"
                imgfile = os.path.join(scriptpath, "lib", "gui",
                                       ".cache", "preview", img)
                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
                logger.debug("Generated preview for GUI: '%s'", img)
            if self._args.preview:
                logger.debug("Generating preview for display: '%s'", name)
                with self._lock:
                    self._preview_buffer[name] = image
                logger.debug("Generated preview for display: '%s'", name)
        except Exception as err:
            logger.error("Could not preview sample")
            raise err
        logger.debug("Updated preview: (name: %s)", name)