Expose Augmentation Options to config

parent 0db645ae41
commit 0e76422805
@@ -21,12 +21,12 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 class TrainingDataGenerator():
     """ Generate training data for models """
-    def __init__(self, model_input_size, model_output_size, training_opts):
+    def __init__(self, model_input_size, model_output_size, training_opts, config):
         logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, "
-                     "training_opts: %s, landmarks: %s)",
+                     "training_opts: %s, landmarks: %s, config: %s)",
                      self.__class__.__name__, model_input_size, model_output_size,
                      {key: val for key, val in training_opts.items() if key != "landmarks"},
-                     bool(training_opts.get("landmarks", None)))
+                     bool(training_opts.get("landmarks", None)), config)
         self.batchsize = 0
         self.model_input_size = model_input_size
         self.model_output_size = model_output_size
@@ -36,7 +36,8 @@ class TrainingDataGenerator():
         self._nearest_landmarks = None
         self.processing = ImageManipulation(model_input_size,
                                             model_output_size,
-                                            training_opts.get("coverage_ratio", 0.625))
+                                            training_opts.get("coverage_ratio", 0.625),
+                                            config)
         logger.debug("Initialized %s", self.__class__.__name__)

     def set_mask_class(self):
@@ -214,18 +215,15 @@ class TrainingDataGenerator():

 class ImageManipulation():
     """ Manipulations to be performed on training images """
-    def __init__(self, input_size, output_size, coverage_ratio):
+    def __init__(self, input_size, output_size, coverage_ratio, config):
         """ input_size: Size of the face input into the model
             output_size: Size of the face that comes out of the modell
             coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160
         """
-        logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
-                     self.__class__.__name__, input_size, output_size, coverage_ratio)
-        # Transform args
-        self.rotation_range = 10  # Range to randomly rotate the image by
-        self.zoom_range = 0.05  # Range to randomly zoom the image by
-        self.shift_range = 0.05  # Range to randomly translate the image by
-        self.random_flip = 0.5  # Chance to flip the image horizontally
+        logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s, "
+                     "config: %s)", self.__class__.__name__, input_size, output_size,
+                     coverage_ratio, config)
+        self.config = config
         # Transform and Warp args
         self.input_size = input_size
         self.output_size = output_size
@@ -246,28 +244,32 @@ class ImageManipulation():
         img[:, :, :3] = face
         return img.astype('float32') / 255.0

-    @staticmethod
-    def random_clahe(image):
+    def random_clahe(self, image):
         """ Randomly perform Contrast Limited Adaptive Histogram Equilization """
-        base_contrast = image.shape[0] // 128
         contrast_random = random()
-        if contrast_random <= 0.5:
-            contrast_adjustment = int((contrast_random * 10.0) * (base_contrast / 2))
-            grid_size = base_contrast + contrast_adjustment
-            logger.trace("Adjusting Contrast. Grid Size: %s", grid_size)
-
-            clahe = cv2.createCLAHE(clipLimit=2.0,  # pylint: disable=no-member
-                                    tileGridSize=(grid_size, grid_size))
-            for chan in range(3):
-                image[:, :, chan] = clahe.apply(image[:, :, chan])
+        if contrast_random > self.config.get("color_clahe_chance", 50) / 100:
+            return image
+
+        base_contrast = image.shape[0] // 128
+        grid_base = random() * self.config.get("color_clahe_max_size", 4)
+        contrast_adjustment = int(grid_base * (base_contrast / 2))
+        grid_size = base_contrast + contrast_adjustment
+        logger.trace("Adjusting Contrast. Grid Size: %s", grid_size)
+
+        clahe = cv2.createCLAHE(clipLimit=2.0,  # pylint: disable=no-member
+                                tileGridSize=(grid_size, grid_size))
+        for chan in range(3):
+            image[:, :, chan] = clahe.apply(image[:, :, chan])
         return image

-    @staticmethod
-    def random_lab(image):
-        """ Perform random color/lightness adjustment in LAB colorspace """
-        randoms = [(random() * 0.6) - 0.3,    # L adjust +/- 30%
-                   (random() * 0.16) - 0.08,  # A adjust +/- 8%
-                   (random() * 0.16) - 0.08]  # B adjust +/- 8%
+    def random_lab(self, image):
+        """ Perform random color/lightness adjustment in L*a*b* colorspace """
+        amount_l = self.config.get("color_lightness", 30) / 100
+        amount_ab = self.config.get("color_ab", 8) / 100
+
+        randoms = [(random() * amount_l * 2) - amount_l,    # L adjust
+                   (random() * amount_ab * 2) - amount_ab,  # A adjust
+                   (random() * amount_ab * 2) - amount_ab]  # B adjust
+
         logger.trace("Random LAB adjustments: %s", randoms)
         image = cv2.cvtColor(  # pylint:disable=no-member
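As a reading aid, here is a minimal standalone sketch of the arithmetic the two methods above now drive from config. The keys and defaults ("color_lightness", "color_ab", "color_clahe_chance", "color_clahe_max_size") come from this commit; the 256px image size and the bare dict standing in for the loaded config are assumptions for illustration only.

    from random import random

    config = {"color_lightness": 30, "color_ab": 8,
              "color_clahe_chance": 50, "color_clahe_max_size": 4}  # stand-in for the loaded config

    # L*a*b* offsets: the exposed percentages become symmetric fractional ranges
    amount_l = config.get("color_lightness", 30) / 100   # 30 -> +/- 0.30 on L
    amount_ab = config.get("color_ab", 8) / 100           # 8  -> +/- 0.08 on a and b
    randoms = [(random() * amount_l * 2) - amount_l,
               (random() * amount_ab * 2) - amount_ab,
               (random() * amount_ab * 2) - amount_ab]

    # CLAHE: skip with the configured probability, otherwise pick a grid size
    # scaled from an assumed 256px training image
    if random() <= config.get("color_clahe_chance", 50) / 100:
        base_contrast = 256 // 128                         # -> 2
        grid_base = random() * config.get("color_clahe_max_size", 4)
        grid_size = base_contrast + int(grid_base * (base_contrast / 2))
        print("CLAHE grid size:", grid_size)

    print("L*a*b* adjustments:", randoms)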
@@ -305,10 +307,15 @@ class ImageManipulation():
         logger.trace("Randomly transforming image")
         height, width = image.shape[0:2]

-        rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
-        scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
-        tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
-        tny = np.random.uniform(-self.shift_range, self.shift_range) * height
+        rotation_range = self.config.get("rotation_range", 10)
+        rotation = np.random.uniform(-rotation_range, rotation_range)
+
+        zoom_range = self.config.get("zoom_range", 5) / 100
+        scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
+
+        shift_range = self.config.get("shift_range", 5) / 100
+        tnx = np.random.uniform(-shift_range, shift_range) * width
+        tny = np.random.uniform(-shift_range, shift_range) * height

         mat = cv2.getRotationMatrix2D(  # pylint:disable=no-member
             (width // 2, height // 2), rotation, scale)
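A hedged sketch of how the new percentage-based ranges build the random affine transform. It assumes numpy and OpenCV are available; the final translation step is an assumption for completeness and is not part of this hunk.

    import cv2
    import numpy as np

    def random_transform_matrix(height, width, config):
        """ Build a random rotation/zoom/shift matrix from the exposed config values """
        rotation_range = config.get("rotation_range", 10)   # degrees
        zoom_range = config.get("zoom_range", 5) / 100        # 5 -> +/- 5%
        shift_range = config.get("shift_range", 5) / 100      # 5 -> +/- 5%

        rotation = np.random.uniform(-rotation_range, rotation_range)
        scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
        tnx = np.random.uniform(-shift_range, shift_range) * width
        tny = np.random.uniform(-shift_range, shift_range) * height

        mat = cv2.getRotationMatrix2D((width // 2, height // 2), rotation, scale)
        mat[:, 2] += (tnx, tny)  # assumed: fold the random shift into the affine matrix
        return mat

    print(random_transform_matrix(256, 256, {}))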
@@ -323,7 +330,8 @@ class ImageManipulation():
     def do_random_flip(self, image):
         """ Perform flip on image if random number is within threshold """
         logger.trace("Randomly flipping image")
-        if np.random.random() < self.random_flip:
+        random_flip = self.config.get("random_flip", 50) / 100
+        if np.random.random() < random_flip:
             logger.trace("Flip within threshold. Flipping")
             retval = image[:, ::-1]
         else:
@@ -74,10 +74,6 @@ class Config(FaceswapConfig):
             info="If using a mask, This penalizes the loss for the masked area, to give higher "
                  "priority to the face area. \nShould increase overall quality and speed up "
                  "training. This should probably be left at True")
-        self.add_item(
-            section=section, title="preview_images", datatype=int, default=14, min_max=(2, 16),
-            rounding=2, fixed=False,
-            info="Number of sample faces to display for each side in the preview when training.")
         logger.debug("Set global config")

     def load_module(self, filename, module_path, plugin_type):
@@ -88,7 +84,8 @@ class Config(FaceswapConfig):
         section = ".".join((plugin_type, module.replace("_defaults", "")))
         logger.debug("Importing defaults module: %s.%s", module_path, module)
         mod = import_module("{}.{}".format(module_path, module))
-        helptext = mod._HELPTEXT + ADDITIONAL_INFO  # pylint:disable=protected-access
+        helptext = mod._HELPTEXT  # pylint:disable=protected-access
+        helptext += ADDITIONAL_INFO if module_path.endswith("model") else ""
         self.add_section(title=section, info=helptext)
         for key, val in mod._DEFAULTS.items():  # pylint:disable=protected-access
             self.add_item(section=section, title=key, **val)
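To make the payoff for the new trainer defaults concrete, a small illustration of the section name the code above derives for a defaults module. The plugin_type value is an assumption for plugins shipped under the trainer directory; only the string handling is taken from the hunk itself.

    plugin_type = "trainer"                  # assumed plugin_type for the trainer plugin directory
    module = "original_defaults"             # from original_defaults.py added below
    section = ".".join((plugin_type, module.replace("_defaults", "")))
    print(section)                           # -> "trainer.original"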
@@ -152,9 +152,6 @@ class ModelBase():
             super() this method for defaults otherwise be sure to add """
         logger.debug("Setting training data")
         # Force number of preview images to between 2 and 16
-        preview_images = self.config.get("preview_images", 14)
-        preview_images = min(max(preview_images, 2), 16)
-        self.training_opts["preview_images"] = preview_images
         self.training_opts["training_size"] = self.state.training_size
         self.training_opts["no_logs"] = self.state.current_session["no_logs"]
         self.training_opts["mask_type"] = self.config.get("mask_type", None)
@@ -15,7 +15,7 @@
                    Set to None for not used
     no_logs: Disable tensorboard logging
     warp_to_landmarks: Use random_warp_landmarks instead of random_warp
-    augment_color: Perform random shifting of LAB colors
+    augment_color: Perform random shifting of L*a*b* colors
     no_flip: Don't perform a random flip on the image
     pingpong: Train each side seperately per save iteration rather than together
 """
@@ -33,16 +33,23 @@ from lib.alignments import Alignments
 from lib.faces_detect import DetectedFace
 from lib.training_data import TrainingDataGenerator, stack_images
 from lib.utils import get_folder, get_image_paths
+from plugins.train._config import Config

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


+def get_config(plugin_name, configfile=None):
+    """ Return the config for the requested model """
+    return Config(plugin_name, configfile=configfile).config_dict
+
+
 class TrainerBase():
     """ Base Trainer """

-    def __init__(self, model, images, batch_size):
+    def __init__(self, model, images, batch_size, configfile):
         logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
                      self.__class__.__name__, model, batch_size)
+        self.config = get_config(".".join(self.__module__.split(".")[-2:]), configfile=configfile)
         self.batch_size = batch_size
         self.model = model
         self.model.state.add_session_batchsize(batch_size)
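And the consumer side: a sketch of how get_config derives the matching section name for the Original trainer, assuming the plugin module is plugins.train.trainer.original (implied by the <plugin_name>_defaults.py convention documented further below).

    module_name = "plugins.train.trainer.original"       # assumed module path of the Original trainer
    plugin_name = ".".join(module_name.split(".")[-2:])
    print(plugin_name)                                    # -> "trainer.original", the section registered above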
@@ -56,7 +63,8 @@ class TrainerBase():
                                          images[side],
                                          self.model,
                                          self.use_mask,
-                                         batch_size)
+                                         batch_size,
+                                         self.config)
                          for side in self.sides}

         self.tensorboard = self.set_tensorboard()
@@ -67,6 +75,7 @@ class TrainerBase():
         self.timelapse = Timelapse(self.model,
                                    self.use_mask,
                                    self.model.training_opts["coverage_ratio"],
+                                   self.config.get("preview_images", 14),
                                    self.batchers)
         logger.debug("Initialized %s", self.__class__.__name__)
@@ -218,13 +227,14 @@ class TrainerBase():

 class Batcher():
     """ Batch images from a single side """
-    def __init__(self, side, images, model, use_mask, batch_size):
-        logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s)",
-                     self.__class__.__name__, side, len(images), batch_size)
+    def __init__(self, side, images, model, use_mask, batch_size, config):
+        logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s, config: %s)",
+                     self.__class__.__name__, side, len(images), batch_size, config)
         self.model = model
         self.use_mask = use_mask
         self.side = side
         self.images = images
+        self.config = config
         self.target = None
         self.samples = None
         self.mask = None
@@ -239,7 +249,10 @@ class Batcher():
         input_size = self.model.input_shape[0]
         output_size = self.model.output_shape[0]
         logger.debug("input_size: %s, output_size: %s", input_size, output_size)
-        generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts)
+        generator = TrainingDataGenerator(input_size,
+                                          output_size,
+                                          self.model.training_opts,
+                                          self.config)
         return generator

     def train_one_batch(self, do_preview):
@@ -289,7 +302,9 @@ class Batcher():
     def set_preview_feed(self):
         """ Set the preview dictionary """
         logger.debug("Setting preview feed: (side: '%s')", self.side)
-        batchsize = min(len(self.images), self.model.training_opts.get("preview_images", 14))
+        preview_images = self.config.get("preview_images", 14)
+        preview_images = min(max(preview_images, 2), 16)
+        batchsize = min(len(self.images), preview_images)
         self.preview_feed = self.load_generator().minibatch_ab(self.images,
                                                                batchsize,
                                                                self.side,
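A trivial illustration of the clamp that moves from the model into the Batcher: the configured preview count is forced into the 2 to 16 range before sizing the preview feed (the oversized value is hypothetical).

    preview_images = 20                               # hypothetical out-of-range config value
    preview_images = min(max(preview_images, 2), 16)  # forced to between 2 and 16
    print(preview_images)                             # -> 16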
@@ -299,7 +314,7 @@ class Batcher():

     def compile_sample(self, batch_size, samples=None, images=None):
         """ Training samples to display in the viewer """
-        num_images = self.model.training_opts.get("preview_images", 14)
+        num_images = self.config.get("preview_images", 14)
         num_images = min(batch_size, num_images) if batch_size is not None else num_images
         logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images)
         images = images if images is not None else self.target
@@ -561,10 +576,11 @@ class Samples():

 class Timelapse():
     """ Create the timelapse """
-    def __init__(self, model, use_mask, coverage_ratio, batchers):
+    def __init__(self, model, use_mask, coverage_ratio, preview_images, batchers):
         logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, "
-                     "batchers: '%s')", self.__class__.__name__, model, use_mask,
-                     coverage_ratio, batchers)
+                     "preview_images: %s, batchers: '%s')", self.__class__.__name__, model,
+                     use_mask, coverage_ratio, preview_images, batchers)
+        self.preview_images = preview_images
         self.samples = Samples(model, use_mask, coverage_ratio)
         self.model = model
         self.batchers = batchers
@@ -591,7 +607,7 @@ class Timelapse():
         images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)}
         batchsize = min(len(images["a"]),
                         len(images["b"]),
-                        self.model.training_opts.get("preview_images", 14))
+                        self.preview_images)
         for side, image_files in images.items():
             self.batchers[side].set_timelapse_feed(image_files, batchsize)
         logger.debug("Set up timelapse")
@@ -1,4 +1,10 @@
 #!/usr/bin/env python3
 """ Original Trainer """

-from ._base import TrainerBase as Trainer
+from ._base import TrainerBase
+
+
+class Trainer(TrainerBase):
+    """ Original is currently identical to Base """
+    def __init__(self, *args, **kwargs):  # pylint:disable=useless-super-delegation
+        super().__init__(*args, **kwargs)
plugins/train/trainer/original_defaults.py (new executable file, 125 lines)
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+"""
+    The default options for the faceswap Original Model plugin.
+
+    Defaults files should be named <plugin_name>_defaults.py
+    Any items placed into this file will automatically get added to the relevant config .ini files
+    within the faceswap/config folder.
+
+    The following variables should be defined:
+        _HELPTEXT: A string describing what this plugin does
+        _DEFAULTS: A dictionary containing the options, defaults and meta information. The
+                   dictionary should be defined as:
+                       {<option_name>: {<metadata>}}
+
+                   <option_name> should always be lower text.
+                   <metadata> dictionary requirements are listed below.
+
+    The following keys are expected for the _DEFAULTS <metadata> dict:
+        datatype:  [required] A python type class. This limits the type of data that can be
+                   provided in the .ini file and ensures that the value is returned in the
+                   correct type to faceswap. Valid datatypes are: <class 'int'>, <class 'float'>,
+                   <class 'str'>, <class 'bool'>.
+        default:   [required] The default value for this option.
+        info:      [required] A string describing what this option does.
+        choices:   [optional] If this option's datatype is of <class 'str'> then valid
+                   selections can be defined here. This validates the option and also enables
+                   a combobox / radio option in the GUI.
+        gui_radio: [optional] If <choices> are defined, this indicates that the GUI should use
+                   radio buttons rather than a combobox to display this option.
+        min_max:   [partial] For <class 'int'> and <class 'float'> datatypes this is required
+                   otherwise it is ignored. Should be a tuple of min and max accepted values.
+                   This is used for controlling the GUI slider range. Values are not enforced.
+        rounding:  [partial] For <class 'int'> and <class 'float'> datatypes this is
+                   required otherwise it is ignored. Used for the GUI slider. For floats, this
+                   is the number of decimal places to display. For ints this is the step size.
+        fixed:     [optional] [train only]. Training configurations are fixed when the model is
+                   created, and then reloaded from the state file. Marking an item as fixed=False
+                   indicates that this value can be changed for existing models, and will override
+                   the value saved in the state file with the updated value in config. If not
+                   provided this will default to True.
+"""
+
+
+_HELPTEXT = ("Original Trainer Options.\n"
+             "WARNING: The defaults for augmentation will be fine for 99.9% of use cases. "
+             "Only change them if you absolutely know what you are doing!")
+
+
+_DEFAULTS = {
+    "preview_images": {
+        "default": 14,
+        "info": "Number of sample faces to display for each side in the preview when training.",
+        "datatype": int,
+        "rounding": 2,
+        "min_max": (2, 16),
+    },
+    "zoom_amount": {
+        "default": 5,
+        "info": "Percentage amount to randomly zoom each training image in and out.",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 25),
+    },
+    "rotation_range": {
+        "default": 10,
+        "info": "Percentage amount to randomly rotate each training image.",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 25),
+    },
+    "shift_range": {
+        "default": 5,
+        "info": "Percentage amount to randomly shift each training image horizontally and "
+                "vertically.",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 0),
+    },
+    "flip_chance": {
+        "default": 50,
+        "info": "Percentage chance to randomly flip each training image horizontally.\n"
+                "NB: This is ignored if the 'no-flip' option is enabled",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 75),
+    },
+    "color_lightness": {
+        "default": 30,
+        "info": "Percentage amount to randomly alter the lightness of each training image.\n"
+                "NB: This is ignored if the 'no-augment-color' option is enabled",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 75),
+    },
+    "color_ab": {
+        "default": 8,
+        "info": "Percentage amount to randomly alter the 'a' and 'b' colors of the L*a*b* color "
+                "space of each training image.\n"
+                "NB: This is ignored if the 'no-augment-color' option is enabled",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 50),
+    },
+    "color_clahe_chance": {
+        "default": 50,
+        "info": "Percentage chance to perform Contrast Limited Adaptive Histogram Equalization on "
+                "each training image.\n"
+                "NB: This is ignored if the 'no-augment-color' option is enabled",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (0, 75),
+        "fixed": False,
+    },
+    "color_clahe_max_size": {
+        "default": 4,
+        "info": "The grid size dictates how much Contrast Limited Adaptive Histogram Equalization "
+                "is performed on any training image selected for clahe. Contrast will be applied "
+                "randomly with a gridsize of 0 up to the maximum. This value is a multiplier "
+                "calculated from the training image size.\n"
+                "NB: This is ignored if the 'no-augment-color' option is enabled",
+        "datatype": int,
+        "rounding": 1,
+        "min_max": (1, 8),
+    },
+}
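For contributors, a hedged example of the shape a further option would take if added to _DEFAULTS, following the metadata keys documented in the module docstring above; the option name and values are purely hypothetical.

    _EXAMPLE_OPTION = {
        "my_new_option": {
            "default": 10,
            "info": "Description of what this hypothetical option does.",
            "datatype": int,
            "rounding": 1,
            "min_max": (0, 100),
            "fixed": False,  # may be changed for existing models, overriding the state file
        },
    }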
@@ -198,7 +198,8 @@ class Train():
         trainer = PluginLoader.get_trainer(model.trainer)
         trainer = trainer(model,
                           self.images,
-                          self.args.batch_size)
+                          self.args.batch_size,
+                          self.args.configfile)
         logger.debug("Loaded Trainer")
         return trainer