lib.model.losses_tf - Add multiplier bug catching code

This commit is contained in:
torzdf 2021-05-05 16:47:26 +01:00
parent 8ec4b4fb82
commit 51705fadb0

View File

@ -581,15 +581,26 @@ class LossWrapper(tf.keras.losses.Loss):
tuple
(n_true, n_pred): The ground truth and predicted value tensors with the mask applied
"""
if mask_channel == -1:
logger.debug("No mask to apply")
return y_true[..., :3], y_pred[..., :3]
try:
if mask_channel == -1:
logger.debug("No mask to apply")
return y_true[..., :3], y_pred[..., :3]
logger.debug("Applying mask from channel %s", mask_channel)
mask = K.expand_dims(y_true[..., mask_channel], axis=-1)
mask_as_k_inv_prop = 1 - mask_prop
mask = (mask * mask_prop) + mask_as_k_inv_prop
n_true = K.concatenate([y_true[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
n_pred = K.concatenate([y_pred[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
except:
logger.error("You have hit a bug which is being actively tracked by the developer.")
logger.error("Please provide the following information so it can be fixed.")
logger.error("y_true: %s, %s, y_pred: %s, %s, mask_channel: %s, mask_prop: %s",
K.int_shape(y_true), K.dtype(y_true), K.int_shape(y_pred),
K.dtype(y_pred), mask_channel, mask_prop)
raise
logger.debug("Applying mask from channel %s", mask_channel)
mask = K.expand_dims(y_true[..., mask_channel], axis=-1)
mask_as_k_inv_prop = 1 - mask_prop
mask = (mask * mask_prop) + mask_as_k_inv_prop
n_true = K.concatenate([y_true[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
n_pred = K.concatenate([y_pred[:, :, :, i:i+1] * mask for i in range(3)], axis=-1)
return n_true, n_pred