- Bugfix: Training - Disable loss multipliers if penalized loss not selected

- Training: Half mouth/eye multiplier defaults
- GUI: Remove analysis callback from convert tab
This commit is contained in:
torzdf 2020-09-24 11:31:54 +01:00
parent f3227b7b62
commit 961f8ff283
5 changed files with 26 additions and 17 deletions

View File

@ -293,7 +293,7 @@ class ControlPanelOption():
var.trace("w",
lambda name, index, mode, cmd=self._command: self._modified_callback(cmd))
if track_modified and self._command in ("train", "convert") and self.title == "Model Dir":
if track_modified and self._command == "train" and self.title == "Model Dir":
var.trace("w", lambda name, index, mode, v=var: self._model_callback(v))
return var

View File

@ -283,19 +283,19 @@ class TrainingDataGenerator(): # pylint:disable=too-few-public-methods
item = self._masks[key]
if item is None and key != "masks":
continue
if item is None and key == "masks":
logger.trace("Creating dummy masks. side: %s", side)
masks = np.ones_like(batch[..., :1], dtype=batch.dtype)
continue
# Expand out partials for eye and mouth masks on first epoch
if item is not None and key in ("eyes", "mouths"):
self._expand_partials(side, item, filenames)
logger.trace("Obtaining masks for batch. (key: %s side: %s)", key, side)
masks = np.array([self._get_mask(item[side][filename], size)
for filename in filenames], dtype=batch.dtype)
masks = self._resize_masks(size, masks)
if item is None and key == "masks":
logger.trace("Creating dummy masks. side: %s", side)
masks = np.ones_like(batch[..., :1], dtype=batch.dtype)
else:
logger.trace("Obtaining masks for batch. (key: %s side: %s)", key, side)
masks = np.array([self._get_mask(item[side][filename], size)
for filename in filenames], dtype=batch.dtype)
masks = self._resize_masks(size, masks)
logger.trace("masks: (key: %s, shape: %s)", key, masks.shape)
batch = np.concatenate((batch, masks), axis=-1)

View File

@ -271,7 +271,7 @@ class Config(FaceswapConfig):
group="loss",
min_max=(1, 40),
rounding=1,
default=6,
default=3,
fixed=False,
info="The amount of priority to give to the eyes.\n\nThe value given here is as a "
"multiplier of the main loss score. For example:"
@ -286,7 +286,7 @@ class Config(FaceswapConfig):
group="loss",
min_max=(1, 40),
rounding=1,
default=4,
default=2,
fixed=False,
info="The amount of priority to give to the mouth.\n\nThe value given here is as a "
"multiplier of the main loss score. For Example:"

View File

@ -1044,9 +1044,17 @@ class _Loss():
list:
A list of channel indices that contain the mask for the corresponding config item
"""
eye_multiplier = self._config["eye_multiplier"]
mouth_multiplier = self._config["mouth_multiplier"]
if not self._config["penalized_mask_loss"] and (eye_multiplier > 1 or
mouth_multiplier > 1):
logger.warning("You have selected eye/mouth loss multipliers greate than 1x, but "
"Penalized Mask Loss is disabled. Disabling all multipliers.")
eye_multiplier = 1
mouth_multiplier = 1
uses_masks = (self._config["penalized_mask_loss"],
self._config["eye_multiplier"] > 1,
self._config["mouth_multiplier"] > 1)
eye_multiplier > 1,
mouth_multiplier > 1)
mask_channels = [-1 for _ in range(len(uses_masks))]
current_channel = 3
for idx, mask_required in enumerate(uses_masks):

View File

@ -107,9 +107,10 @@ class TrainerBase():
are required for training, `masks_eye` if eye masks are required and `masks_mouth` if
mouth masks are required. """
retval = dict()
penalized_loss = self._model.config["penalized_mask_loss"]
if not any([self._model.config["learn_mask"],
self._model.config["penalized_mask_loss"],
penalized_loss,
self._model.config["eye_multiplier"] > 1,
self._model.config["mouth_multiplier"] > 1,
self._model.command_line_arguments.warp_to_landmarks]):
@ -121,14 +122,14 @@ class TrainerBase():
logger.debug("Adding landmarks to training opts dict")
retval["landmarks"] = alignments.landmarks
if self._model.config["learn_mask"] or self._model.config["penalized_mask_loss"]:
if self._model.config["learn_mask"] or penalized_loss:
logger.debug("Adding masks to training opts dict")
retval["masks"] = alignments.masks
if self._model.config["eye_multiplier"] > 1:
if penalized_loss and self._model.config["eye_multiplier"] > 1:
retval["masks_eye"] = alignments.masks_eye
if self._model.config["mouth_multiplier"] > 1:
if penalized_loss and self._model.config["mouth_multiplier"] > 1:
retval["masks_mouth"] = alignments.masks_mouth
logger.debug({key: {k: len(v) for k, v in val.items()} for key, val in retval.items()})