mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
lib.model.losses_tf - Add multiplier bug catching code
This commit is contained in:
parent
8ec4b4fb82
commit
51705fadb0
|
|
@ -581,15 +581,26 @@ class LossWrapper(tf.keras.losses.Loss):
|
|||
tuple
|
||||
(n_true, n_pred): The ground truth and predicted value tensors with the mask applied
|
||||
"""
|
||||
if mask_channel == -1:
|
||||
logger.debug("No mask to apply")
|
||||
return y_true[..., :3], y_pred[..., :3]
|
||||
try:
|
||||
if mask_channel == -1:
|
||||
logger.debug("No mask to apply")
|
||||
return y_true[..., :3], y_pred[..., :3]
|
||||
|
||||
logger.debug("Applying mask from channel %s", mask_channel)
|
||||
mask = K.expand_dims(y_true[..., mask_channel], axis=-1)
|
||||
mask_as_k_inv_prop = 1 - mask_prop
|
||||
mask = (mask * mask_prop) + mask_as_k_inv_prop
|
||||
|
||||
n_true = K.concatenate([y_true[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
|
||||
n_pred = K.concatenate([y_pred[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
|
||||
|
||||
except:
|
||||
logger.error("You have hit a bug which is being actively tracked by the developer.")
|
||||
logger.error("Please provide the following information so it can be fixed.")
|
||||
logger.error("y_true: %s, %s, y_pred: %s, %s, mask_channel: %s, mask_prop: %s",
|
||||
K.int_shape(y_true), K.dtype(y_true), K.int_shape(y_pred),
|
||||
K.dtype(y_pred), mask_channel, mask_prop)
|
||||
raise
|
||||
|
||||
logger.debug("Applying mask from channel %s", mask_channel)
|
||||
mask = K.expand_dims(y_true[..., mask_channel], axis=-1)
|
||||
mask_as_k_inv_prop = 1 - mask_prop
|
||||
mask = (mask * mask_prop) + mask_as_k_inv_prop
|
||||
|
||||
n_true = K.concatenate([y_true[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
|
||||
n_pred = K.concatenate([y_pred[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
|
||||
return n_true, n_pred
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user