Expose Augmentation Options to config

torzdf 2019-06-26 10:48:38 +00:00
parent 0db645ae41
commit 0e76422805
7 changed files with 207 additions and 57 deletions

View File

@@ -21,12 +21,12 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 class TrainingDataGenerator():
     """ Generate training data for models """
-    def __init__(self, model_input_size, model_output_size, training_opts):
+    def __init__(self, model_input_size, model_output_size, training_opts, config):
         logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, "
-                     "training_opts: %s, landmarks: %s)",
+                     "training_opts: %s, landmarks: %s, config: %s)",
                      self.__class__.__name__, model_input_size, model_output_size,
                      {key: val for key, val in training_opts.items() if key != "landmarks"},
-                     bool(training_opts.get("landmarks", None)))
+                     bool(training_opts.get("landmarks", None)), config)
         self.batchsize = 0
         self.model_input_size = model_input_size
         self.model_output_size = model_output_size
@@ -36,7 +36,8 @@ class TrainingDataGenerator():
         self._nearest_landmarks = None
         self.processing = ImageManipulation(model_input_size,
                                             model_output_size,
-                                            training_opts.get("coverage_ratio", 0.625))
+                                            training_opts.get("coverage_ratio", 0.625),
+                                            config)
         logger.debug("Initialized %s", self.__class__.__name__)

     def set_mask_class(self):
@@ -214,18 +215,15 @@ class TrainingDataGenerator():


 class ImageManipulation():
     """ Manipulations to be performed on training images """
-    def __init__(self, input_size, output_size, coverage_ratio):
+    def __init__(self, input_size, output_size, coverage_ratio, config):
         """ input_size: Size of the face input into the model
             output_size: Size of the face that comes out of the model
             coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160
             """
-        logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
-                     self.__class__.__name__, input_size, output_size, coverage_ratio)
-        # Transform args
-        self.rotation_range = 10  # Range to randomly rotate the image by
-        self.zoom_range = 0.05  # Range to randomly zoom the image by
-        self.shift_range = 0.05  # Range to randomly translate the image by
-        self.random_flip = 0.5  # Chance to flip the image horizontally
+        logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s, "
+                     "config: %s)", self.__class__.__name__, input_size, output_size,
+                     coverage_ratio, config)
+        self.config = config
         # Transform and Warp args
         self.input_size = input_size
         self.output_size = output_size
@@ -246,28 +244,32 @@ class ImageManipulation():
             img[:, :, :3] = face
         return img.astype('float32') / 255.0

-    @staticmethod
-    def random_clahe(image):
+    def random_clahe(self, image):
         """ Randomly perform Contrast Limited Adaptive Histogram Equalization """
-        base_contrast = image.shape[0] // 128
         contrast_random = random()
-        if contrast_random <= 0.5:
-            contrast_adjustment = int((contrast_random * 10.0) * (base_contrast / 2))
-            grid_size = base_contrast + contrast_adjustment
-            logger.trace("Adjusting Contrast. Grid Size: %s", grid_size)
-            clahe = cv2.createCLAHE(clipLimit=2.0,  # pylint: disable=no-member
-                                    tileGridSize=(grid_size, grid_size))
-            for chan in range(3):
-                image[:, :, chan] = clahe.apply(image[:, :, chan])
+        if contrast_random > self.config.get("color_clahe_chance", 50) / 100:
+            return image
+        base_contrast = image.shape[0] // 128
+        grid_base = random() * self.config.get("color_clahe_max_size", 4)
+        contrast_adjustment = int(grid_base * (base_contrast / 2))
+        grid_size = base_contrast + contrast_adjustment
+        logger.trace("Adjusting Contrast. Grid Size: %s", grid_size)
+        clahe = cv2.createCLAHE(clipLimit=2.0,  # pylint: disable=no-member
+                                tileGridSize=(grid_size, grid_size))
+        for chan in range(3):
+            image[:, :, chan] = clahe.apply(image[:, :, chan])
         return image

-    @staticmethod
-    def random_lab(image):
-        """ Perform random color/lightness adjustment in LAB colorspace """
-        randoms = [(random() * 0.6) - 0.3,    # L adjust +/- 30%
-                   (random() * 0.16) - 0.08,  # A adjust +/- 8%
-                   (random() * 0.16) - 0.08]  # B adjust +/- 8%
+    def random_lab(self, image):
+        """ Perform random color/lightness adjustment in L*a*b* colorspace """
+        amount_l = self.config.get("color_lightness", 30) / 100
+        amount_ab = self.config.get("color_ab", 8) / 100
+
+        randoms = [(random() * amount_l * 2) - amount_l,      # L adjust
+                   (random() * amount_ab * 2) - amount_ab,    # A adjust
+                   (random() * amount_ab * 2) - amount_ab]    # B adjust
         logger.trace("Random LAB adjustments: %s", randoms)
         image = cv2.cvtColor(  # pylint:disable=no-member
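
For reference, here is how the new colour options drive the augmentation, as a self-contained sketch using the shipped defaults on a dummy 256px image. The standalone dict, dummy image and print are illustrative only and are not part of the commit.

# Sketch (not from the commit): colour augmentation driven by the new config keys.
from random import random

import cv2
import numpy as np

config = {"color_clahe_chance": 50, "color_clahe_max_size": 4,
          "color_lightness": 30, "color_ab": 8}
image = np.random.randint(0, 256, (256, 256, 3), dtype="uint8")

# CLAHE is skipped (100 - color_clahe_chance)% of the time. When it runs, the tile
# grid grows from an image-size-derived base (256 // 128 = 2 here) by a random
# multiple capped by color_clahe_max_size.
if random() <= config.get("color_clahe_chance", 50) / 100:
    base_contrast = image.shape[0] // 128
    grid_base = random() * config.get("color_clahe_max_size", 4)
    grid_size = base_contrast + int(grid_base * (base_contrast / 2))
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(grid_size, grid_size))
    for chan in range(3):
        image[:, :, chan] = clahe.apply(image[:, :, chan])

# L*a*b* shifts: lightness varies by up to +/- color_lightness percent, the a and b
# channels by up to +/- color_ab percent (+/- 30% and +/- 8% with these defaults).
amount_l = config.get("color_lightness", 30) / 100
amount_ab = config.get("color_ab", 8) / 100
randoms = [(random() * amount_l * 2) - amount_l,
           (random() * amount_ab * 2) - amount_ab,
           (random() * amount_ab * 2) - amount_ab]
print("L/a/b adjustments:", randoms)
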
@@ -305,10 +307,15 @@ class ImageManipulation():
         logger.trace("Randomly transforming image")
         height, width = image.shape[0:2]

-        rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
-        scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
-        tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
-        tny = np.random.uniform(-self.shift_range, self.shift_range) * height
+        rotation_range = self.config.get("rotation_range", 10)
+        rotation = np.random.uniform(-rotation_range, rotation_range)
+
+        zoom_range = self.config.get("zoom_range", 5) / 100
+        scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
+
+        shift_range = self.config.get("shift_range", 5) / 100
+        tnx = np.random.uniform(-shift_range, shift_range) * width
+        tny = np.random.uniform(-shift_range, shift_range) * height

         mat = cv2.getRotationMatrix2D(  # pylint:disable=no-member
             (width // 2, height // 2), rotation, scale)
@@ -323,7 +330,8 @@ class ImageManipulation():
     def do_random_flip(self, image):
         """ Perform flip on image if random number is within threshold """
         logger.trace("Randomly flipping image")
-        if np.random.random() < self.random_flip:
+        random_flip = self.config.get("random_flip", 50) / 100
+        if np.random.random() < random_flip:
             logger.trace("Flip within threshold. Flipping")
             retval = image[:, ::-1]
         else:
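
The geometric options follow the same pattern: integer percentages from the config are converted into the ranges the affine transform expects. A minimal standalone sketch with the default values, assuming a 256px training image (the dict literal and the dummy image are illustrative, not part of the commit):

# Sketch (not from the commit): how the percentage options become an affine warp.
import cv2
import numpy as np

config = {"rotation_range": 10, "zoom_range": 5, "shift_range": 5, "random_flip": 50}
image = np.zeros((256, 256, 3), dtype="uint8")
height, width = image.shape[:2]

rotation_range = config.get("rotation_range", 10)      # +/- 10 degrees
zoom_range = config.get("zoom_range", 5) / 100         # scale factor 0.95 .. 1.05
shift_range = config.get("shift_range", 5) / 100       # +/- 5% of width/height
flip_chance = config.get("random_flip", 50) / 100      # flip half of the images

rotation = np.random.uniform(-rotation_range, rotation_range)
scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
tnx = np.random.uniform(-shift_range, shift_range) * width    # +/- ~12.8px at 256px
tny = np.random.uniform(-shift_range, shift_range) * height

mat = cv2.getRotationMatrix2D((width // 2, height // 2), rotation, scale)
mat[:, 2] += (tnx, tny)                                 # add the random translation
warped = cv2.warpAffine(image, mat, (width, height), borderMode=cv2.BORDER_REPLICATE)
if np.random.random() < flip_chance:
    warped = warped[:, ::-1]
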

View File

@ -74,10 +74,6 @@ class Config(FaceswapConfig):
info="If using a mask, This penalizes the loss for the masked area, to give higher " info="If using a mask, This penalizes the loss for the masked area, to give higher "
"priority to the face area. \nShould increase overall quality and speed up " "priority to the face area. \nShould increase overall quality and speed up "
"training. This should probably be left at True") "training. This should probably be left at True")
self.add_item(
section=section, title="preview_images", datatype=int, default=14, min_max=(2, 16),
rounding=2, fixed=False,
info="Number of sample faces to display for each side in the preview when training.")
logger.debug("Set global config") logger.debug("Set global config")
def load_module(self, filename, module_path, plugin_type): def load_module(self, filename, module_path, plugin_type):
@@ -88,7 +84,8 @@ class Config(FaceswapConfig):
         section = ".".join((plugin_type, module.replace("_defaults", "")))
         logger.debug("Importing defaults module: %s.%s", module_path, module)
         mod = import_module("{}.{}".format(module_path, module))
-        helptext = mod._HELPTEXT + ADDITIONAL_INFO  # pylint:disable=protected-access
+        helptext = mod._HELPTEXT  # pylint:disable=protected-access
+        helptext += ADDITIONAL_INFO if module_path.endswith("model") else ""
         self.add_section(title=section, info=helptext)
         for key, val in mod._DEFAULTS.items():  # pylint:disable=protected-access
             self.add_item(section=section, title=key, **val)

View File

@@ -152,9 +152,6 @@ class ModelBase():
             super() this method for defaults otherwise be sure to add """
         logger.debug("Setting training data")
         # Force number of preview images to between 2 and 16
-        preview_images = self.config.get("preview_images", 14)
-        preview_images = min(max(preview_images, 2), 16)
-        self.training_opts["preview_images"] = preview_images
         self.training_opts["training_size"] = self.state.training_size
         self.training_opts["no_logs"] = self.state.current_session["no_logs"]
         self.training_opts["mask_type"] = self.config.get("mask_type", None)

View File

@@ -15,7 +15,7 @@
               Set to None for not used
     no_logs: Disable tensorboard logging
     warp_to_landmarks: Use random_warp_landmarks instead of random_warp
-    augment_color: Perform random shifting of LAB colors
+    augment_color: Perform random shifting of L*a*b* colors
     no_flip: Don't perform a random flip on the image
     pingpong: Train each side separately per save iteration rather than together
 """
@@ -33,16 +33,23 @@ from lib.alignments import Alignments
 from lib.faces_detect import DetectedFace
 from lib.training_data import TrainingDataGenerator, stack_images
 from lib.utils import get_folder, get_image_paths
+from plugins.train._config import Config

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


+def get_config(plugin_name, configfile=None):
+    """ Return the config for the requested model """
+    return Config(plugin_name, configfile=configfile).config_dict
+
+
 class TrainerBase():
     """ Base Trainer """

-    def __init__(self, model, images, batch_size):
+    def __init__(self, model, images, batch_size, configfile):
         logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
                      self.__class__.__name__, model, batch_size)
+        self.config = get_config(".".join(self.__module__.split(".")[-2:]), configfile=configfile)
         self.batch_size = batch_size
         self.model = model
         self.model.state.add_session_batchsize(batch_size)
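
The plugin name handed to get_config is derived from the calling trainer's module path, which is presumably why original.py now defines a Trainer subclass (see the later hunk) instead of aliasing TrainerBase. A short sketch of that derivation and of how callers read the resulting flat dict (the literal values are illustrative; real ones come from the user's .ini):

# Sketch (not from the commit): deriving the section name and reading options.
module = "plugins.train.trainer.original"        # what self.__module__ resolves to
plugin_name = ".".join(module.split(".")[-2:])   # -> "trainer.original"

# get_config(plugin_name) returns a flat {option: value} dict for that section.
config = {"preview_images": 14, "rotation_range": 10, "shift_range": 5}

# Callers read with a hard-coded fallback, so a missing key degrades to the old
# built-in default instead of raising.
preview_images = min(max(config.get("preview_images", 14), 2), 16)
rotation_range = config.get("rotation_range", 10)
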
@@ -56,7 +63,8 @@ class TrainerBase():
                                  images[side],
                                  self.model,
                                  self.use_mask,
-                                 batch_size)
+                                 batch_size,
+                                 self.config)
                          for side in self.sides}

         self.tensorboard = self.set_tensorboard()
@@ -67,6 +75,7 @@ class TrainerBase():
         self.timelapse = Timelapse(self.model,
                                    self.use_mask,
                                    self.model.training_opts["coverage_ratio"],
+                                   self.config.get("preview_images", 14),
                                    self.batchers)
         logger.debug("Initialized %s", self.__class__.__name__)
@@ -218,13 +227,14 @@ class TrainerBase():


 class Batcher():
     """ Batch images from a single side """
-    def __init__(self, side, images, model, use_mask, batch_size):
-        logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s)",
-                     self.__class__.__name__, side, len(images), batch_size)
+    def __init__(self, side, images, model, use_mask, batch_size, config):
+        logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s, config: %s)",
+                     self.__class__.__name__, side, len(images), batch_size, config)
         self.model = model
         self.use_mask = use_mask
         self.side = side
         self.images = images
+        self.config = config
         self.target = None
         self.samples = None
         self.mask = None
@@ -239,7 +249,10 @@ class Batcher():
         input_size = self.model.input_shape[0]
         output_size = self.model.output_shape[0]
         logger.debug("input_size: %s, output_size: %s", input_size, output_size)
-        generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts)
+        generator = TrainingDataGenerator(input_size,
+                                          output_size,
+                                          self.model.training_opts,
+                                          self.config)
         return generator

     def train_one_batch(self, do_preview):
@@ -289,7 +302,9 @@ class Batcher():
     def set_preview_feed(self):
         """ Set the preview dictionary """
         logger.debug("Setting preview feed: (side: '%s')", self.side)
-        batchsize = min(len(self.images), self.model.training_opts.get("preview_images", 14))
+        preview_images = self.config.get("preview_images", 14)
+        preview_images = min(max(preview_images, 2), 16)
+        batchsize = min(len(self.images), preview_images)
         self.preview_feed = self.load_generator().minibatch_ab(self.images,
                                                                batchsize,
                                                                self.side,
@@ -299,7 +314,7 @@ class Batcher():

     def compile_sample(self, batch_size, samples=None, images=None):
         """ Training samples to display in the viewer """
-        num_images = self.model.training_opts.get("preview_images", 14)
+        num_images = self.config.get("preview_images", 14)
         num_images = min(batch_size, num_images) if batch_size is not None else num_images
         logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images)
         images = images if images is not None else self.target
@@ -561,10 +576,11 @@ class Samples():


 class Timelapse():
     """ Create the timelapse """
-    def __init__(self, model, use_mask, coverage_ratio, batchers):
+    def __init__(self, model, use_mask, coverage_ratio, preview_images, batchers):
         logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, "
-                     "batchers: '%s')", self.__class__.__name__, model, use_mask,
-                     coverage_ratio, batchers)
+                     "preview_images: %s, batchers: '%s')", self.__class__.__name__, model,
+                     use_mask, coverage_ratio, preview_images, batchers)
+        self.preview_images = preview_images
         self.samples = Samples(model, use_mask, coverage_ratio)
         self.model = model
         self.batchers = batchers
@@ -591,7 +607,7 @@ class Timelapse():
         images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)}
         batchsize = min(len(images["a"]),
                         len(images["b"]),
-                        self.model.training_opts.get("preview_images", 14))
+                        self.preview_images)
         for side, image_files in images.items():
             self.batchers[side].set_timelapse_feed(image_files, batchsize)
         logger.debug("Set up timelapse")

View File

@@ -1,4 +1,10 @@
 #!/usr/bin/env python3
 """ Original Trainer """
-from ._base import TrainerBase as Trainer
+from ._base import TrainerBase
+
+
+class Trainer(TrainerBase):
+    """ Original is currently identical to Base """
+    def __init__(self, *args, **kwargs):  # pylint:disable=useless-super-delegation
+        super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
    The default options for the faceswap Original Trainer plugin.

    Defaults files should be named <plugin_name>_defaults.py
    Any items placed into this file will automatically get added to the relevant config .ini files
    within the faceswap/config folder.

    The following variables should be defined:
        _HELPTEXT: A string describing what this plugin does
        _DEFAULTS: A dictionary containing the options, defaults and meta information. The
                   dictionary should be defined as:
                       {<option_name>: {<metadata>}}

                   <option_name> should always be lowercase text.
                   <metadata> dictionary requirements are listed below.

    The following keys are expected for the _DEFAULTS <metadata> dict:
        datatype:  [required] A python type class. This limits the type of data that can be
                   provided in the .ini file and ensures that the value is returned in the
                   correct type to faceswap. Valid datatypes are: <class 'int'>, <class 'float'>,
                   <class 'str'>, <class 'bool'>.
        default:   [required] The default value for this option.
        info:      [required] A string describing what this option does.
        choices:   [optional] If this option's datatype is of <class 'str'> then valid
                   selections can be defined here. This validates the option and also enables
                   a combobox / radio option in the GUI.
        gui_radio: [optional] If <choices> are defined, this indicates that the GUI should use
                   radio buttons rather than a combobox to display this option.
        min_max:   [partial] For <class 'int'> and <class 'float'> datatypes this is required
                   otherwise it is ignored. Should be a tuple of min and max accepted values.
                   This is used for controlling the GUI slider range. Values are not enforced.
        rounding:  [partial] For <class 'int'> and <class 'float'> datatypes this is
                   required otherwise it is ignored. Used for the GUI slider. For floats, this
                   is the number of decimal places to display. For ints this is the step size.
        fixed:     [optional] [train only]. Training configurations are fixed when the model is
                   created, and then reloaded from the state file. Marking an item as fixed=False
                   indicates that this value can be changed for existing models, and will
                   override the value saved in the state file with the updated value in config.
                   If not provided this will default to True.
"""
_HELPTEXT = ("Original Trainer Options.\n"
             "WARNING: The defaults for augmentation will be fine for 99.9% of use cases. "
             "Only change them if you absolutely know what you are doing!")


_DEFAULTS = {
    "preview_images": {
        "default": 14,
        "info": "Number of sample faces to display for each side in the preview when training.",
        "datatype": int,
        "rounding": 2,
        "min_max": (2, 16),
    },
    "zoom_amount": {
        "default": 5,
        "info": "Percentage amount to randomly zoom each training image in and out.",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 25),
    },
    "rotation_range": {
        "default": 10,
        "info": "Percentage amount to randomly rotate each training image.",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 25),
    },
    "shift_range": {
        "default": 5,
        "info": "Percentage amount to randomly shift each training image horizontally and "
                "vertically.",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 25),
    },
    "flip_chance": {
        "default": 50,
        "info": "Percentage chance to randomly flip each training image horizontally.\n"
                "NB: This is ignored if the 'no-flip' option is enabled",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 75),
    },
    "color_lightness": {
        "default": 30,
        "info": "Percentage amount to randomly alter the lightness of each training image.\n"
                "NB: This is ignored if the 'no-augment-color' option is enabled",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 75),
    },
    "color_ab": {
        "default": 8,
        "info": "Percentage amount to randomly alter the 'a' and 'b' colors of the L*a*b* color "
                "space of each training image.\n"
                "NB: This is ignored if the 'no-augment-color' option is enabled",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 50),
    },
    "color_clahe_chance": {
        "default": 50,
        "info": "Percentage chance to perform Contrast Limited Adaptive Histogram Equalization "
                "on each training image.\n"
                "NB: This is ignored if the 'no-augment-color' option is enabled",
        "datatype": int,
        "rounding": 1,
        "min_max": (0, 75),
        "fixed": False,
    },
    "color_clahe_max_size": {
        "default": 4,
        "info": "The grid size dictates how much Contrast Limited Adaptive Histogram "
                "Equalization is performed on any training image selected for clahe. Contrast "
                "will be applied randomly with a gridsize of 0 up to the maximum. This value is "
                "a multiplier calculated from the training image size.\n"
                "NB: This is ignored if the 'no-augment-color' option is enabled",
        "datatype": int,
        "rounding": 1,
        "min_max": (1, 8),
    },
}
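
Per the module docstring above, these entries are written into the relevant .ini file under the faceswap config folder and read back through FaceswapConfig. A rough sketch of that round trip follows; the file name below is an assumption, while the "trainer.original" section name follows from the section-naming code in _config.py shown earlier:

# Sketch only: the defaults above end up as a [trainer.original] section that can
# be read back with configparser. The file path and layout are assumed here.
from configparser import ConfigParser

parser = ConfigParser()
parser.read("config/train.ini")
if parser.has_section("trainer.original"):
    section = parser["trainer.original"]
    preview_images = section.getint("preview_images", fallback=14)
    clahe_chance = section.getint("color_clahe_chance", fallback=50) / 100
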

View File

@@ -198,7 +198,8 @@ class Train():
         trainer = PluginLoader.get_trainer(model.trainer)
         trainer = trainer(model,
                           self.images,
-                          self.args.batch_size)
+                          self.args.batch_size,
+                          self.args.configfile)
         logger.debug("Loaded Trainer")
         return trainer