mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
- Bugfix: Training - Disable loss multipliers if penalized loss not selected
- Training: Half mouth/eye multiplier defaults - GUI: Remove analysis callback from convert tab
This commit is contained in:
parent
f3227b7b62
commit
961f8ff283
|
|
@ -293,7 +293,7 @@ class ControlPanelOption():
|
|||
var.trace("w",
|
||||
lambda name, index, mode, cmd=self._command: self._modified_callback(cmd))
|
||||
|
||||
if track_modified and self._command in ("train", "convert") and self.title == "Model Dir":
|
||||
if track_modified and self._command == "train" and self.title == "Model Dir":
|
||||
var.trace("w", lambda name, index, mode, v=var: self._model_callback(v))
|
||||
|
||||
return var
|
||||
|
|
|
|||
|
|
@ -283,19 +283,19 @@ class TrainingDataGenerator(): # pylint:disable=too-few-public-methods
|
|||
item = self._masks[key]
|
||||
if item is None and key != "masks":
|
||||
continue
|
||||
if item is None and key == "masks":
|
||||
logger.trace("Creating dummy masks. side: %s", side)
|
||||
masks = np.ones_like(batch[..., :1], dtype=batch.dtype)
|
||||
continue
|
||||
|
||||
# Expand out partials for eye and mouth masks on first epoch
|
||||
if item is not None and key in ("eyes", "mouths"):
|
||||
self._expand_partials(side, item, filenames)
|
||||
|
||||
logger.trace("Obtaining masks for batch. (key: %s side: %s)", key, side)
|
||||
masks = np.array([self._get_mask(item[side][filename], size)
|
||||
for filename in filenames], dtype=batch.dtype)
|
||||
masks = self._resize_masks(size, masks)
|
||||
if item is None and key == "masks":
|
||||
logger.trace("Creating dummy masks. side: %s", side)
|
||||
masks = np.ones_like(batch[..., :1], dtype=batch.dtype)
|
||||
else:
|
||||
logger.trace("Obtaining masks for batch. (key: %s side: %s)", key, side)
|
||||
masks = np.array([self._get_mask(item[side][filename], size)
|
||||
for filename in filenames], dtype=batch.dtype)
|
||||
masks = self._resize_masks(size, masks)
|
||||
|
||||
logger.trace("masks: (key: %s, shape: %s)", key, masks.shape)
|
||||
batch = np.concatenate((batch, masks), axis=-1)
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ class Config(FaceswapConfig):
|
|||
group="loss",
|
||||
min_max=(1, 40),
|
||||
rounding=1,
|
||||
default=6,
|
||||
default=3,
|
||||
fixed=False,
|
||||
info="The amount of priority to give to the eyes.\n\nThe value given here is as a "
|
||||
"multiplier of the main loss score. For example:"
|
||||
|
|
@ -286,7 +286,7 @@ class Config(FaceswapConfig):
|
|||
group="loss",
|
||||
min_max=(1, 40),
|
||||
rounding=1,
|
||||
default=4,
|
||||
default=2,
|
||||
fixed=False,
|
||||
info="The amount of priority to give to the mouth.\n\nThe value given here is as a "
|
||||
"multiplier of the main loss score. For Example:"
|
||||
|
|
|
|||
|
|
@ -1044,9 +1044,17 @@ class _Loss():
|
|||
list:
|
||||
A list of channel indices that contain the mask for the corresponding config item
|
||||
"""
|
||||
eye_multiplier = self._config["eye_multiplier"]
|
||||
mouth_multiplier = self._config["mouth_multiplier"]
|
||||
if not self._config["penalized_mask_loss"] and (eye_multiplier > 1 or
|
||||
mouth_multiplier > 1):
|
||||
logger.warning("You have selected eye/mouth loss multipliers greate than 1x, but "
|
||||
"Penalized Mask Loss is disabled. Disabling all multipliers.")
|
||||
eye_multiplier = 1
|
||||
mouth_multiplier = 1
|
||||
uses_masks = (self._config["penalized_mask_loss"],
|
||||
self._config["eye_multiplier"] > 1,
|
||||
self._config["mouth_multiplier"] > 1)
|
||||
eye_multiplier > 1,
|
||||
mouth_multiplier > 1)
|
||||
mask_channels = [-1 for _ in range(len(uses_masks))]
|
||||
current_channel = 3
|
||||
for idx, mask_required in enumerate(uses_masks):
|
||||
|
|
|
|||
|
|
@ -107,9 +107,10 @@ class TrainerBase():
|
|||
are required for training, `masks_eye` if eye masks are required and `masks_mouth` if
|
||||
mouth masks are required. """
|
||||
retval = dict()
|
||||
penalized_loss = self._model.config["penalized_mask_loss"]
|
||||
|
||||
if not any([self._model.config["learn_mask"],
|
||||
self._model.config["penalized_mask_loss"],
|
||||
penalized_loss,
|
||||
self._model.config["eye_multiplier"] > 1,
|
||||
self._model.config["mouth_multiplier"] > 1,
|
||||
self._model.command_line_arguments.warp_to_landmarks]):
|
||||
|
|
@ -121,14 +122,14 @@ class TrainerBase():
|
|||
logger.debug("Adding landmarks to training opts dict")
|
||||
retval["landmarks"] = alignments.landmarks
|
||||
|
||||
if self._model.config["learn_mask"] or self._model.config["penalized_mask_loss"]:
|
||||
if self._model.config["learn_mask"] or penalized_loss:
|
||||
logger.debug("Adding masks to training opts dict")
|
||||
retval["masks"] = alignments.masks
|
||||
|
||||
if self._model.config["eye_multiplier"] > 1:
|
||||
if penalized_loss and self._model.config["eye_multiplier"] > 1:
|
||||
retval["masks_eye"] = alignments.masks_eye
|
||||
|
||||
if self._model.config["mouth_multiplier"] > 1:
|
||||
if penalized_loss and self._model.config["mouth_multiplier"] > 1:
|
||||
retval["masks_mouth"] = alignments.masks_mouth
|
||||
|
||||
logger.debug({key: {k: len(v) for k, v in val.items()} for key, val in retval.items()})
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user