mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 12:20:27 +01:00
Expose Augmentation Options to config
This commit is contained in:
parent
0db645ae41
commit
0e76422805
|
|
@ -21,12 +21,12 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|||
|
||||
class TrainingDataGenerator():
|
||||
""" Generate training data for models """
|
||||
def __init__(self, model_input_size, model_output_size, training_opts):
|
||||
def __init__(self, model_input_size, model_output_size, training_opts, config):
|
||||
logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, "
|
||||
"training_opts: %s, landmarks: %s)",
|
||||
"training_opts: %s, landmarks: %s, config: %s)",
|
||||
self.__class__.__name__, model_input_size, model_output_size,
|
||||
{key: val for key, val in training_opts.items() if key != "landmarks"},
|
||||
bool(training_opts.get("landmarks", None)))
|
||||
bool(training_opts.get("landmarks", None)), config)
|
||||
self.batchsize = 0
|
||||
self.model_input_size = model_input_size
|
||||
self.model_output_size = model_output_size
|
||||
|
|
@ -36,7 +36,8 @@ class TrainingDataGenerator():
|
|||
self._nearest_landmarks = None
|
||||
self.processing = ImageManipulation(model_input_size,
|
||||
model_output_size,
|
||||
training_opts.get("coverage_ratio", 0.625))
|
||||
training_opts.get("coverage_ratio", 0.625),
|
||||
config)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def set_mask_class(self):
|
||||
|
|
@ -214,18 +215,15 @@ class TrainingDataGenerator():
|
|||
|
||||
class ImageManipulation():
|
||||
""" Manipulations to be performed on training images """
|
||||
def __init__(self, input_size, output_size, coverage_ratio):
|
||||
def __init__(self, input_size, output_size, coverage_ratio, config):
|
||||
""" input_size: Size of the face input into the model
|
||||
output_size: Size of the face that comes out of the modell
|
||||
coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160
|
||||
"""
|
||||
logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
|
||||
self.__class__.__name__, input_size, output_size, coverage_ratio)
|
||||
# Transform args
|
||||
self.rotation_range = 10 # Range to randomly rotate the image by
|
||||
self.zoom_range = 0.05 # Range to randomly zoom the image by
|
||||
self.shift_range = 0.05 # Range to randomly translate the image by
|
||||
self.random_flip = 0.5 # Chance to flip the image horizontally
|
||||
logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s, "
|
||||
"config: %s)", self.__class__.__name__, input_size, output_size,
|
||||
coverage_ratio, config)
|
||||
self.config = config
|
||||
# Transform and Warp args
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
|
|
@ -246,13 +244,15 @@ class ImageManipulation():
|
|||
img[:, :, :3] = face
|
||||
return img.astype('float32') / 255.0
|
||||
|
||||
@staticmethod
|
||||
def random_clahe(image):
|
||||
def random_clahe(self, image):
|
||||
""" Randomly perform Contrast Limited Adaptive Histogram Equilization """
|
||||
base_contrast = image.shape[0] // 128
|
||||
contrast_random = random()
|
||||
if contrast_random <= 0.5:
|
||||
contrast_adjustment = int((contrast_random * 10.0) * (base_contrast / 2))
|
||||
if contrast_random > self.config.get("color_clahe_chance", 50) / 100:
|
||||
return image
|
||||
|
||||
base_contrast = image.shape[0] // 128
|
||||
grid_base = random() * self.config.get("color_clahe_max_size", 4)
|
||||
contrast_adjustment = int(grid_base * (base_contrast / 2))
|
||||
grid_size = base_contrast + contrast_adjustment
|
||||
logger.trace("Adjusting Contrast. Grid Size: %s", grid_size)
|
||||
|
||||
|
|
@ -262,12 +262,14 @@ class ImageManipulation():
|
|||
image[:, :, chan] = clahe.apply(image[:, :, chan])
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def random_lab(image):
|
||||
""" Perform random color/lightness adjustment in LAB colorspace """
|
||||
randoms = [(random() * 0.6) - 0.3, # L adjust +/- 30%
|
||||
(random() * 0.16) - 0.08, # A adjust +/- 8%
|
||||
(random() * 0.16) - 0.08] # B adjust +/- 8%
|
||||
def random_lab(self, image):
|
||||
""" Perform random color/lightness adjustment in L*a*b* colorspace """
|
||||
amount_l = self.config.get("color_lightness", 30) / 100
|
||||
amount_ab = self.config.get("color_ab", 8) / 100
|
||||
|
||||
randoms = [(random() * amount_l * 2) - amount_l, # L adjust
|
||||
(random() * amount_ab * 2) - amount_ab, # A adjust
|
||||
(random() * amount_ab * 2) - amount_ab] # B adjust
|
||||
|
||||
logger.trace("Random LAB adjustments: %s", randoms)
|
||||
image = cv2.cvtColor( # pylint:disable=no-member
|
||||
|
|
@ -305,10 +307,15 @@ class ImageManipulation():
|
|||
logger.trace("Randomly transforming image")
|
||||
height, width = image.shape[0:2]
|
||||
|
||||
rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
|
||||
scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
|
||||
tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
|
||||
tny = np.random.uniform(-self.shift_range, self.shift_range) * height
|
||||
rotation_range = self.config.get("rotation_range", 10)
|
||||
rotation = np.random.uniform(-rotation_range, rotation_range)
|
||||
|
||||
zoom_range = self.config.get("zoom_range", 5) / 100
|
||||
scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
|
||||
|
||||
shift_range = self.config.get("shift_range", 5) / 100
|
||||
tnx = np.random.uniform(-shift_range, shift_range) * width
|
||||
tny = np.random.uniform(-shift_range, shift_range) * height
|
||||
|
||||
mat = cv2.getRotationMatrix2D( # pylint:disable=no-member
|
||||
(width // 2, height // 2), rotation, scale)
|
||||
|
|
@ -323,7 +330,8 @@ class ImageManipulation():
|
|||
def do_random_flip(self, image):
|
||||
""" Perform flip on image if random number is within threshold """
|
||||
logger.trace("Randomly flipping image")
|
||||
if np.random.random() < self.random_flip:
|
||||
random_flip = self.config.get("random_flip", 50) / 100
|
||||
if np.random.random() < random_flip:
|
||||
logger.trace("Flip within threshold. Flipping")
|
||||
retval = image[:, ::-1]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -74,10 +74,6 @@ class Config(FaceswapConfig):
|
|||
info="If using a mask, This penalizes the loss for the masked area, to give higher "
|
||||
"priority to the face area. \nShould increase overall quality and speed up "
|
||||
"training. This should probably be left at True")
|
||||
self.add_item(
|
||||
section=section, title="preview_images", datatype=int, default=14, min_max=(2, 16),
|
||||
rounding=2, fixed=False,
|
||||
info="Number of sample faces to display for each side in the preview when training.")
|
||||
logger.debug("Set global config")
|
||||
|
||||
def load_module(self, filename, module_path, plugin_type):
|
||||
|
|
@ -88,7 +84,8 @@ class Config(FaceswapConfig):
|
|||
section = ".".join((plugin_type, module.replace("_defaults", "")))
|
||||
logger.debug("Importing defaults module: %s.%s", module_path, module)
|
||||
mod = import_module("{}.{}".format(module_path, module))
|
||||
helptext = mod._HELPTEXT + ADDITIONAL_INFO # pylint:disable=protected-access
|
||||
helptext = mod._HELPTEXT # pylint:disable=protected-access
|
||||
helptext += ADDITIONAL_INFO if module_path.endswith("model") else ""
|
||||
self.add_section(title=section, info=helptext)
|
||||
for key, val in mod._DEFAULTS.items(): # pylint:disable=protected-access
|
||||
self.add_item(section=section, title=key, **val)
|
||||
|
|
|
|||
|
|
@ -152,9 +152,6 @@ class ModelBase():
|
|||
super() this method for defaults otherwise be sure to add """
|
||||
logger.debug("Setting training data")
|
||||
# Force number of preview images to between 2 and 16
|
||||
preview_images = self.config.get("preview_images", 14)
|
||||
preview_images = min(max(preview_images, 2), 16)
|
||||
self.training_opts["preview_images"] = preview_images
|
||||
self.training_opts["training_size"] = self.state.training_size
|
||||
self.training_opts["no_logs"] = self.state.current_session["no_logs"]
|
||||
self.training_opts["mask_type"] = self.config.get("mask_type", None)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
Set to None for not used
|
||||
no_logs: Disable tensorboard logging
|
||||
warp_to_landmarks: Use random_warp_landmarks instead of random_warp
|
||||
augment_color: Perform random shifting of LAB colors
|
||||
augment_color: Perform random shifting of L*a*b* colors
|
||||
no_flip: Don't perform a random flip on the image
|
||||
pingpong: Train each side seperately per save iteration rather than together
|
||||
"""
|
||||
|
|
@ -33,16 +33,23 @@ from lib.alignments import Alignments
|
|||
from lib.faces_detect import DetectedFace
|
||||
from lib.training_data import TrainingDataGenerator, stack_images
|
||||
from lib.utils import get_folder, get_image_paths
|
||||
from plugins.train._config import Config
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_config(plugin_name, configfile=None):
|
||||
""" Return the config for the requested model """
|
||||
return Config(plugin_name, configfile=configfile).config_dict
|
||||
|
||||
|
||||
class TrainerBase():
|
||||
""" Base Trainer """
|
||||
|
||||
def __init__(self, model, images, batch_size):
|
||||
def __init__(self, model, images, batch_size, configfile):
|
||||
logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
|
||||
self.__class__.__name__, model, batch_size)
|
||||
self.config = get_config(".".join(self.__module__.split(".")[-2:]), configfile=configfile)
|
||||
self.batch_size = batch_size
|
||||
self.model = model
|
||||
self.model.state.add_session_batchsize(batch_size)
|
||||
|
|
@ -56,7 +63,8 @@ class TrainerBase():
|
|||
images[side],
|
||||
self.model,
|
||||
self.use_mask,
|
||||
batch_size)
|
||||
batch_size,
|
||||
self.config)
|
||||
for side in self.sides}
|
||||
|
||||
self.tensorboard = self.set_tensorboard()
|
||||
|
|
@ -67,6 +75,7 @@ class TrainerBase():
|
|||
self.timelapse = Timelapse(self.model,
|
||||
self.use_mask,
|
||||
self.model.training_opts["coverage_ratio"],
|
||||
self.config.get("preview_images", 14),
|
||||
self.batchers)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
|
|
@ -218,13 +227,14 @@ class TrainerBase():
|
|||
|
||||
class Batcher():
|
||||
""" Batch images from a single side """
|
||||
def __init__(self, side, images, model, use_mask, batch_size):
|
||||
logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s)",
|
||||
self.__class__.__name__, side, len(images), batch_size)
|
||||
def __init__(self, side, images, model, use_mask, batch_size, config):
|
||||
logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s, config: %s)",
|
||||
self.__class__.__name__, side, len(images), batch_size, config)
|
||||
self.model = model
|
||||
self.use_mask = use_mask
|
||||
self.side = side
|
||||
self.images = images
|
||||
self.config = config
|
||||
self.target = None
|
||||
self.samples = None
|
||||
self.mask = None
|
||||
|
|
@ -239,7 +249,10 @@ class Batcher():
|
|||
input_size = self.model.input_shape[0]
|
||||
output_size = self.model.output_shape[0]
|
||||
logger.debug("input_size: %s, output_size: %s", input_size, output_size)
|
||||
generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts)
|
||||
generator = TrainingDataGenerator(input_size,
|
||||
output_size,
|
||||
self.model.training_opts,
|
||||
self.config)
|
||||
return generator
|
||||
|
||||
def train_one_batch(self, do_preview):
|
||||
|
|
@ -289,7 +302,9 @@ class Batcher():
|
|||
def set_preview_feed(self):
|
||||
""" Set the preview dictionary """
|
||||
logger.debug("Setting preview feed: (side: '%s')", self.side)
|
||||
batchsize = min(len(self.images), self.model.training_opts.get("preview_images", 14))
|
||||
preview_images = self.config.get("preview_images", 14)
|
||||
preview_images = min(max(preview_images, 2), 16)
|
||||
batchsize = min(len(self.images), preview_images)
|
||||
self.preview_feed = self.load_generator().minibatch_ab(self.images,
|
||||
batchsize,
|
||||
self.side,
|
||||
|
|
@ -299,7 +314,7 @@ class Batcher():
|
|||
|
||||
def compile_sample(self, batch_size, samples=None, images=None):
|
||||
""" Training samples to display in the viewer """
|
||||
num_images = self.model.training_opts.get("preview_images", 14)
|
||||
num_images = self.config.get("preview_images", 14)
|
||||
num_images = min(batch_size, num_images) if batch_size is not None else num_images
|
||||
logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images)
|
||||
images = images if images is not None else self.target
|
||||
|
|
@ -561,10 +576,11 @@ class Samples():
|
|||
|
||||
class Timelapse():
|
||||
""" Create the timelapse """
|
||||
def __init__(self, model, use_mask, coverage_ratio, batchers):
|
||||
def __init__(self, model, use_mask, coverage_ratio, preview_images, batchers):
|
||||
logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, "
|
||||
"batchers: '%s')", self.__class__.__name__, model, use_mask,
|
||||
coverage_ratio, batchers)
|
||||
"preview_images: %s, batchers: '%s')", self.__class__.__name__, model,
|
||||
use_mask, coverage_ratio, preview_images, batchers)
|
||||
self.preview_images = preview_images
|
||||
self.samples = Samples(model, use_mask, coverage_ratio)
|
||||
self.model = model
|
||||
self.batchers = batchers
|
||||
|
|
@ -591,7 +607,7 @@ class Timelapse():
|
|||
images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)}
|
||||
batchsize = min(len(images["a"]),
|
||||
len(images["b"]),
|
||||
self.model.training_opts.get("preview_images", 14))
|
||||
self.preview_images)
|
||||
for side, image_files in images.items():
|
||||
self.batchers[side].set_timelapse_feed(image_files, batchsize)
|
||||
logger.debug("Set up timelapse")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Original Trainer """
|
||||
|
||||
from ._base import TrainerBase as Trainer
|
||||
from ._base import TrainerBase
|
||||
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
""" Original is currently identical to Base """
|
||||
def __init__(self, *args, **kwargs): # pylint:disable=useless-super-delegation
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
|
|||
125
plugins/train/trainer/original_defaults.py
Executable file
125
plugins/train/trainer/original_defaults.py
Executable file
|
|
@ -0,0 +1,125 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
The default options for the faceswap Original Model plugin.
|
||||
|
||||
Defaults files should be named <plugin_name>_defaults.py
|
||||
Any items placed into this file will automatically get added to the relevant config .ini files
|
||||
within the faceswap/config folder.
|
||||
|
||||
The following variables should be defined:
|
||||
_HELPTEXT: A string describing what this plugin does
|
||||
_DEFAULTS: A dictionary containing the options, defaults and meta information. The
|
||||
dictionary should be defined as:
|
||||
{<option_name>: {<metadata>}}
|
||||
|
||||
<option_name> should always be lower text.
|
||||
<metadata> dictionary requirements are listed below.
|
||||
|
||||
The following keys are expected for the _DEFAULTS <metadata> dict:
|
||||
datatype: [required] A python type class. This limits the type of data that can be
|
||||
provided in the .ini file and ensures that the value is returned in the
|
||||
correct type to faceswap. Valid datatypes are: <class 'int'>, <class 'float'>,
|
||||
<class 'str'>, <class 'bool'>.
|
||||
default: [required] The default value for this option.
|
||||
info: [required] A string describing what this option does.
|
||||
choices: [optional] If this option's datatype is of <class 'str'> then valid
|
||||
selections can be defined here. This validates the option and also enables
|
||||
a combobox / radio option in the GUI.
|
||||
gui_radio: [optional] If <choices> are defined, this indicates that the GUI should use
|
||||
radio buttons rather than a combobox to display this option.
|
||||
min_max: [partial] For <class 'int'> and <class 'float'> datatypes this is required
|
||||
otherwise it is ignored. Should be a tuple of min and max accepted values.
|
||||
This is used for controlling the GUI slider range. Values are not enforced.
|
||||
rounding: [partial] For <class 'int'> and <class 'float'> datatypes this is
|
||||
required otherwise it is ignored. Used for the GUI slider. For floats, this
|
||||
is the number of decimal places to display. For ints this is the step size.
|
||||
fixed: [optional] [train only]. Training configurations are fixed when the model is
|
||||
created, and then reloaded from the state file. Marking an item as fixed=False
|
||||
indicates that this value can be changed for existing models, and will override
|
||||
the value saved in the state file with the updated value in config. If not
|
||||
provided this will default to True.
|
||||
"""
|
||||
|
||||
|
||||
_HELPTEXT = ("Original Trainer Options.\n"
|
||||
"WARNING: The defaults for augmentation will be fine for 99.9% of use cases. "
|
||||
"Only change them if you absolutely know what you are doing!")
|
||||
|
||||
|
||||
_DEFAULTS = {
|
||||
"preview_images": {
|
||||
"default": 14,
|
||||
"info": "Number of sample faces to display for each side in the preview when training.",
|
||||
"datatype": int,
|
||||
"rounding": 2,
|
||||
"min_max": (2, 16),
|
||||
},
|
||||
"zoom_amount": {
|
||||
"default": 5,
|
||||
"info": "Percentage amount to randomly zoom each training image in and out.",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 25),
|
||||
},
|
||||
"rotation_range": {
|
||||
"default": 10,
|
||||
"info": "Percentage amount to randomly rotate each training image.",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 25),
|
||||
},
|
||||
"shift_range": {
|
||||
"default": 5,
|
||||
"info": "Percentage amount to randomly shift each training image horizontally and "
|
||||
"vertically.",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 0),
|
||||
},
|
||||
"flip_chance": {
|
||||
"default": 50,
|
||||
"info": "Percentage chance to randomly flip each training image horizontally.\n"
|
||||
"NB: This is ignored if the 'no-flip' option is enabled",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 75),
|
||||
},
|
||||
"color_lightness": {
|
||||
"default": 30,
|
||||
"info": "Percentage amount to randomly alter the lightness of each training image.\n"
|
||||
"NB: This is ignored if the 'no-augment-color' option is enabled",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 75),
|
||||
},
|
||||
"color_ab": {
|
||||
"default": 8,
|
||||
"info": "Percentage amount to randomly alter the 'a' and 'b' colors of the L*a*b* color "
|
||||
"space of each training image.\n"
|
||||
"NB: This is ignored if the 'no-augment-color' option is enabled",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 50),
|
||||
},
|
||||
"color_clahe_chance": {
|
||||
"default": 50,
|
||||
"info": "Percentage chance to perform Contrast Limited Adaptive Histogram Equalization on "
|
||||
"each training image.\n"
|
||||
"NB: This is ignored if the 'no-augment-color' option is enabled",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (0, 75),
|
||||
"fixed": False,
|
||||
},
|
||||
"color_clahe_max_size": {
|
||||
"default": 4,
|
||||
"info": "The grid size dictates how much Contrast Limited Adaptive Histogram Equalization "
|
||||
"is performed on any training image selected for clahe. Contrast will be applied "
|
||||
"randomly with a gridsize of 0 up to the maximum. This value is a multiplier "
|
||||
"calculated from the training image size.\n"
|
||||
"NB: This is ignored if the 'no-augment-color' option is enabled",
|
||||
"datatype": int,
|
||||
"rounding": 1,
|
||||
"min_max": (1, 8),
|
||||
},
|
||||
}
|
||||
|
|
@ -198,7 +198,8 @@ class Train():
|
|||
trainer = PluginLoader.get_trainer(model.trainer)
|
||||
trainer = trainer(model,
|
||||
self.images,
|
||||
self.args.batch_size)
|
||||
self.args.batch_size,
|
||||
self.args.configfile)
|
||||
logger.debug("Loaded Trainer")
|
||||
return trainer
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user