Update code to support Tensorflow versions up to 2.8 (#1213)

* Update maximum tf version in setup + requirements

* - bump max version of tf version in launcher
- standardise tf version check

* update keras get_custom_objects  for tf>2.6

* bugfix: force black text in GUI file dialogs (linux)

* dssim loss - Move to stock tf.ssim function

* Update optimizer imports for compatibility

* fix logging for tf2.8

* Fix GUI graphing for TF2.8

* update tests

* bump requirements.txt versions

* Remove limit on nvidia-ml-py

* Graphing bugfixes
  - Prevent live graph from displaying if data not yet available

* bugfix: Live graph. Collect loss labels correctly

* fix: live graph - swallow inconsistent loss errors

* Bugfix: Prevent live graph from clearing during training

* Fix graphing for AMD
This commit is contained in:
torzdf 2022-05-02 14:30:43 +01:00 committed by GitHub
parent cda49b3c3c
commit c1512fd41d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 551 additions and 504 deletions

View File

@ -1,19 +1,16 @@
tqdm>=4.62
tqdm>=4.64
psutil>=5.8.0
numpy>=1.18.0,<1.20.0
opencv-python>=4.5.3.0
pillow>=8.3.1
scikit-learn>=0.24.2
fastcluster>=1.1.26
numpy>=1.18.0
opencv-python>=4.5.5.0
pillow>=9.0.1
scikit-learn>=1.0.2
fastcluster>=1.2.4
# matplotlib 3.3.1 breaks custom toolbar in graph popup
matplotlib>=3.2.0,<3.3.0
imageio>=2.9.0
imageio-ffmpeg>=0.4.5
imageio-ffmpeg>=0.4.7
ffmpy==0.2.3
# Exclude badly numbered Python2 version of nvidia-ml-py
# nvidia-ml-py>=11.450,<300
# v11.515.0 changes dtype of output items. Pinned for now
# TODO update code to use latest version
nvidia-ml-py>=11.450,<11.515
nvidia-ml-py>=11.510,<300
pywin32>=228 ; sys_platform == "win32"
pynvx==1.0.0 ; sys_platform == "darwin"

View File

@ -9,8 +9,8 @@ from importlib import import_module
from lib.gpu_stats import set_exclude_devices, GPUStats
from lib.logger import crash_log, log_setup
from lib.utils import (FaceswapError, get_backend, KerasFinder, safe_shutdown, set_backend,
set_system_verbosity)
from lib.utils import (FaceswapError, get_backend, get_tf_version, KerasFinder, safe_shutdown,
set_backend, set_system_verbosity)
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -41,7 +41,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
self._test_for_tf_version()
self._test_for_gui()
cmd = os.path.basename(sys.argv[0])
src = "tools.{}".format(self._command.lower()) if cmd == "tools.py" else "scripts"
src = f"tools.{self._command.lower()}" if cmd == "tools.py" else "scripts"
mod = ".".join((src, self._command.lower()))
module = import_module(mod)
script = getattr(module, self._command.title())
@ -53,15 +53,15 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
Raises
------
FaceswapError
If Tensorflow is not found, or is not between versions 2.2 and 2.6
If Tensorflow is not found, or is not between versions 2.2 and 2.8
"""
min_ver = 2.2
max_ver = 2.6
max_ver = 2.8
try:
# Ensure tensorflow doesn't pin all threads to one core when using Math Kernel Library
os.environ["TF_MIN_GPU_MULTIPROCESSOR_COUNT"] = "4"
os.environ["KMP_AFFINITY"] = "disabled"
import tensorflow as tf # pylint:disable=import-outside-toplevel
import tensorflow as tf # noqa pylint:disable=import-outside-toplevel,unused-import
except ImportError as err:
if "DLL load failed while importing" in str(err):
msg = (
@ -77,14 +77,14 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
f"error: {str(err)}")
self._handle_import_error(msg)
tf_ver = float(".".join(tf.__version__.split(".")[:2])) # pylint:disable=no-member
tf_ver = get_tf_version()
if tf_ver < min_ver:
msg = ("The minimum supported Tensorflow is version {} but you have version {} "
"installed. Please upgrade Tensorflow.".format(min_ver, tf_ver))
msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version "
f"{tf_ver} installed. Please upgrade Tensorflow.")
self._handle_import_error(msg)
if tf_ver > max_ver:
msg = ("The maximum supported Tensorflow is version {} but you have version {} "
"installed. Please downgrade Tensorflow.".format(max_ver, tf_ver))
msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version "
f"{tf_ver} installed. Please downgrade Tensorflow.")
self._handle_import_error(msg)
logger.debug("Installed Tensorflow Version: %s", tf_ver)

View File

@ -7,10 +7,12 @@ import zlib
import numpy as np
import tensorflow as tf
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import errors_impl as tf_errors
from tensorflow.core.util import event_pb2 # pylint:disable=no-name-in-module
from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
errors_impl as tf_errors)
from lib.serializer import get_serializer
from lib.utils import get_backend
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -43,7 +45,7 @@ class _LogFiles():
The full path of each log file for each training session id that has been run
"""
logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
retval = dict()
retval = {}
for dirpath, _, filenames in os.walk(self._logs_folder):
if not any(filename.startswith("events.out.tfevents") for filename in filenames):
continue
@ -133,7 +135,7 @@ class _Cache():
def __init__(self, session_ids):
logger.debug("Initializing: %s: (session_ids: %s)", self.__class__.__name__, session_ids)
self._data = {idx: None for idx in session_ids}
self._carry_over = dict()
self._carry_over = {}
self._loss_labels = []
logger.debug("Initialized: %s", self.__class__.__name__)
@ -158,18 +160,19 @@ class _Cache():
"""
logger.debug("Caching event data: (session_id: %s, labels: %s, data points: %s, "
"is_live: %s)", session_id, labels, len(data), is_live)
if not data:
logger.debug("No data to cache")
return
if labels:
logger.debug("Setting loss labels: %s", labels)
self._loss_labels = labels
if not data:
logger.debug("No data to cache")
return
timestamps, loss = self._to_numpy(data, is_live)
if not is_live or (is_live and not self._data.get(session_id, None)):
self._data[session_id] = dict(labels=labels,
self._data[session_id] = dict(labels=self._loss_labels,
loss=zlib.compress(loss),
loss_shape=loss.shape,
timestamps=zlib.compress(timestamps),
@ -207,10 +210,30 @@ class _Cache():
for idx in sorted(data)])
times, loss = self._process_data(data, times, loss, is_live)
times, loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
if is_live and not all(len(val) == len(self._loss_labels) for val in loss):
# TODO Many attempts have been made to fix this for live graph logging, and the issue
# of non-consistent loss record sizes keeps coming up. In the meantime we shall swallow
# any loss values that are of incorrect length so graph remains functional. This will,
# most likely, lead to a mismatch on iteration count so a proper fix should be
# implemented.
# Timestamps and loss appears to remain consistent with each other, but sometimes loss
# appears non-consistent. eg (lengths):
# [2, 2, 2, 2, 2, 2, 2, 0] - last loss collection has zero length
# [1, 2, 2, 2, 2, 2, 2, 2] - 1st loss collection has 1 length
# [2, 2, 2, 3, 2, 2, 2] - 4th loss collection has 3 length
logger.debug("Inconsistent loss found in collection: %s", loss)
for idx in reversed(range(len(loss))):
if len(loss[idx]) != len(self._loss_labels):
logger.debug("Removing loss/timestamps at position %s", idx)
del loss[idx]
del times[idx]
times, loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
logger.debug("Converted to numpy: (data points: %s, timestamps shape: %s, loss shape: %s)",
len(data), times.shape, loss.shape)
return times, loss
def _collect_carry_over(self, data):
@ -334,7 +357,7 @@ class _Cache():
dtype = "float32" if metric == "loss" else "float64"
retval = dict()
retval = {}
for idx, data in raw.items():
val = {metric: np.frombuffer(zlib.decompress(data[metric]),
dtype=dtype).reshape(data[f"{metric}_shape"])}
@ -461,7 +484,7 @@ class TensorBoardLogs():
and list of loss values for each step
"""
logger.debug("Getting loss: (session_id: %s)", session_id)
retval = dict()
retval = {}
for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx)
data = self._cache.get_data(idx, "loss")
@ -493,7 +516,7 @@ class TensorBoardLogs():
logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
session_id, self._is_training)
retval = dict()
retval = {}
for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx)
data = self._cache.get_data(idx, "timestamps")
@ -565,7 +588,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
session_id: int
The session id that the data is being cached for
"""
data = dict()
data = {}
try:
for record in self._iterator:
event = event_pb2.Event.FromString(record) # pylint:disable=no-member
@ -573,8 +596,11 @@ class _EventParser(): # pylint:disable=too-few-public-methods
continue
if event.summary.value[0].tag == "keras":
self._parse_outputs(event)
if get_backend() == "amd":
# No model is logged for AMD so need to get loss labels from state file
self._add_amd_loss_labels(session_id)
if event.summary.value[0].tag.startswith("batch_"):
data[event.step] = self._process_event(event, data.get(event.step, dict()))
data[event.step] = self._process_event(event, data.get(event.step, {}))
except tf_errors.DataLossError as err:
logger.warning("The logs for Session %s are corrupted and cannot be displayed. "
@ -605,10 +631,6 @@ class _EventParser(): # pylint:disable=too-few-public-methods
config = serializer.unmarshal(struct)["config"]
model_outputs = self._get_outputs(config)
# loss length of unique should be 3:
# - decoder_both, 1, 2
# - docoder_a, decoder_b, 1
split_output = len(np.unique(model_outputs[..., :2])) != 3
for side_outputs, side in zip(model_outputs, ("a", "b")):
logger.debug("side: '%s', outputs: '%s'", side, side_outputs)
layer_name = side_outputs[0][0]
@ -618,8 +640,10 @@ class _EventParser(): # pylint:disable=too-few-public-methods
layer_outputs = self._get_outputs(output_config)
for output in layer_outputs: # Drill into sub-model to get the actual output names
loss_name = output[0][0]
if not split_output: # Rename losses to reflect the side's output
loss_name = f"{loss_name.replace('_both', '')}_{side}"
if loss_name[-2:] not in ("_a", "_b"): # Rename losses to reflect the side output
new_name = f"{loss_name.replace('_both', '')}_{side}"
logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)
loss_name = new_name
if loss_name not in self._loss_labels:
logger.debug("Adding loss name: '%s'", loss_name)
self._loss_labels.append(loss_name)
@ -650,6 +674,28 @@ class _EventParser(): # pylint:disable=too-few-public-methods
outputs, outputs.shape)
return outputs
def _add_amd_loss_labels(self, session_id):
""" It is not possible to store the model config in the Tensorboard logs for AMD so we
need to obtain the loss labels from the model's state file. This is called now so we know
event data is being written, and therefore the most current loss label data is available
in the state file.
Loss names are added to :attr:`_loss_labels`
Parameters
----------
session_id: int
The session id that the data is being cached for
"""
if self._cache._loss_labels: # pylint:disable=protected-access
return
# Import global session here to prevent circular import
from . import Session # pylint:disable=import-outside-toplevel
loss_labels = sorted(Session.get_loss_keys(session_id=session_id))
self._loss_labels = loss_labels
logger.debug("Collated loss labels: %s", self._loss_labels)
@classmethod
def _process_event(cls, event, step):
""" Process a single Tensorflow event.
@ -670,8 +716,19 @@ class _EventParser(): # pylint:disable=too-few-public-methods
The given step `dict` with the given event data added to it.
"""
summary = event.summary.value[0]
if summary.tag in ("batch_loss", "batch_total"): # Pre tf2.3 totals were "batch_total"
step["timestamp"] = event.wall_time
return step
step.setdefault("loss", list()).append(summary.simple_value)
loss = summary.simple_value
if not loss:
# Need to convert a tensor to a float for TF2.8 logged data. This maybe due to change
# in logging or may be due to work around put in place in FS training function for the
# following bug in TF 2.8 when writing records:
# https://github.com/keras-team/keras/issues/16173
loss = float(tf.make_ndarray(summary.tensor))
step.setdefault("loss", []).append(loss)
return step

View File

@ -17,6 +17,7 @@ from threading import Event
import numpy as np
from lib.serializer import get_serializer
from lib.utils import get_backend
from .event_reader import TensorBoardLogs
@ -62,9 +63,9 @@ class GlobalSession():
def batch_sizes(self):
""" dict: The batch sizes for each session_id for the model. """
if self._state is None:
return dict()
return {}
return {int(sess_id): sess["batchsize"]
for sess_id, sess in self._state.get("sessions", dict()).items()}
for sess_id, sess in self._state.get("sessions", {}).items()}
@property
def full_summary(self):
@ -86,7 +87,7 @@ class GlobalSession():
def _load_state_file(self):
""" Load the current state file to :attr:`_state`. """
state_file = os.path.join(self._model_dir, "{}_state.json".format(self._model_name))
state_file = os.path.join(self._model_dir, f"{self._model_name}_state.json")
logger.debug("Loading State: '%s'", state_file)
serializer = get_serializer("json")
self._state = serializer.load(state_file)
@ -125,8 +126,7 @@ class GlobalSession():
self._model_dir = model_folder
self._model_name = model_name
self._load_state_file()
self._tb_logs = TensorBoardLogs(os.path.join(self._model_dir,
"{}_logs".format(self._model_name)),
self._tb_logs = TensorBoardLogs(os.path.join(self._model_dir, f"{self._model_name}_logs"),
is_training)
self._summary = SessionsSummary(self)
@ -140,7 +140,7 @@ class GlobalSession():
def clear(self):
""" Clear the currently loaded session. """
self._state = dict()
self._state = {}
self._model_dir = None
self._model_name = None
@ -173,13 +173,13 @@ class GlobalSession():
loss_dict = self._tb_logs.get_loss(session_id=session_id)
if session_id is None:
retval = dict()
retval = {}
for key in sorted(loss_dict):
for loss_key, loss in loss_dict[key].items():
retval.setdefault(loss_key, []).extend(loss)
retval = {key: np.array(val, dtype="float32") for key, val in retval.items()}
else:
retval = loss_dict.get(session_id, dict())
retval = loss_dict.get(session_id, {})
if self._is_training:
self._is_querying.clear()
@ -239,14 +239,21 @@ class GlobalSession():
The loss keys for the given session. If ``None`` is passed as session_id then a unique
list of all loss keys for all sessions is returned
"""
loss_keys = {sess_id: list(logs.keys())
for sess_id, logs in self._tb_logs.get_loss(session_id=session_id).items()}
if get_backend() == "amd":
# We can't log the graph in Tensorboard logs for AMD so need to obtain from state file
loss_keys = {int(sess_id): [name for name in session["loss_names"] if name != "total"]
for sess_id, session in self._state["sessions"].items()}
else:
loss_keys = {sess_id: list(logs.keys())
for sess_id, logs
in self._tb_logs.get_loss(session_id=session_id).items()}
if session_id is None:
retval = list(set(loss_key
for session in loss_keys.values()
for loss_key in session))
else:
retval = loss_keys[session_id]
retval = loss_keys.get(session_id)
return retval
@ -334,7 +341,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
"""
if self._per_session_stats is None:
logger.debug("Collating per session stats")
compiled = list()
compiled = []
for session_id, ts_data in self._time_stats.items():
logger.debug("Compiling session ID: %s", session_id)
if self._state is None:
@ -446,15 +453,15 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
retval = []
for summary in compiled_stats:
hrs, mins, secs = self._convert_time(summary["elapsed"])
stats = dict()
stats = {}
for key in summary:
if key not in ("start", "end", "elapsed", "rate"):
stats[key] = summary[key]
continue
stats["start"] = time.strftime("%x %X", time.localtime(summary["start"]))
stats["end"] = time.strftime("%x %X", time.localtime(summary["end"]))
stats["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
stats["rate"] = "{0:.1f}".format(summary["rate"])
stats["elapsed"] = f"{hrs}:{mins}:{secs}"
stats["rate"] = f"{summary['rate']:.1f}"
retval.append(stats)
return retval
@ -474,9 +481,9 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
"""
hrs = int(timestamp // 3600)
if hrs < 10:
hrs = "{0:02d}".format(hrs)
mins = "{0:02d}".format((int(timestamp % 3600) // 60))
secs = "{0:02d}".format((int(timestamp % 3600) % 60))
hrs = f"{hrs:02d}"
mins = f"{(int(timestamp % 3600) // 60):02d}"
secs = f"{(int(timestamp % 3600) % 60):02d}"
return hrs, mins, secs
@ -529,7 +536,7 @@ class Calculations():
self._iterations = 0
self._limit = 0
self._start_iteration = 0
self._stats = dict()
self._stats = {}
self.refresh()
logger.debug("Initialized %s", self.__class__.__name__)
@ -630,7 +637,7 @@ class Calculations():
if self._args["flatten_outliers"]:
loss = self._flatten_outliers(loss)
self.stats["raw_{}".format(loss_name)] = loss
self.stats[f"raw_{loss_name}"] = loss
self._iterations = 0 if not iterations else min(iterations)
if self._limit > 1:
@ -642,7 +649,7 @@ class Calculations():
if len(iterations) > 1:
# Crop all losses to the same number of items
if self._iterations == 0:
self.stats = {lossname: np.array(list(), dtype=loss.dtype)
self.stats = {lossname: np.array([], dtype=loss.dtype)
for lossname, loss in self.stats.items()}
else:
self.stats = {lossname: loss[:self._iterations]
@ -722,7 +729,7 @@ class Calculations():
logger.debug("Calculating totals rate")
batchsizes = _SESSION.batch_sizes
total_timestamps = _SESSION.get_timestamps(None)
rate = list()
rate = []
for sess_id in sorted(total_timestamps.keys()):
batchsize = batchsizes[sess_id]
timestamps = total_timestamps[sess_id]
@ -737,10 +744,10 @@ class Calculations():
if selection == "raw":
continue
logger.debug("Calculating: %s", selection)
method = getattr(self, "_calc_{}".format(selection))
method = getattr(self, f"_calc_{selection}")
raw_keys = [key for key in self._stats if key.startswith("raw_")]
for key in raw_keys:
selected_key = "{}_{}".format(selection, key.replace("raw_", ""))
selected_key = f"{selection}_{key.replace('raw_', '')}"
self._stats[selected_key] = method(self._stats[key])
def _calc_avg(self, data):
@ -866,7 +873,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
optimizations.
"""
# Use :func:`np.finfo(dtype).eps` if you are worried about accuracy and want to be safe.
epsilon = np.finfo(self._dtype).tiny
epsilon = np.finfo(self._dtype).tiny # pylint:disable=no-member
# If this produces an OverflowError, make epsilon larger:
retval = int(np.log(epsilon) / np.log(1 - self._alpha)) + 1
logger.debug("row_size: %s", retval)

View File

@ -63,13 +63,10 @@ class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
return
filename = "extract_convert_preview"
now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = os.path.join(location,
"{}_{}.{}".format(filename,
now,
"png"))
filename = os.path.join(location, f"{filename}_{now}.png")
get_images().previewoutput[0].save(filename)
logger.debug("Saved preview to %s", filename)
print("Saved preview to {}".format(filename))
print(f"Saved preview to {filename}")
class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
@ -125,7 +122,7 @@ class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
should_update = self.update_preview.get()
for name in sortednames:
if name not in existing.keys():
if name not in existing:
self.add_child(name)
elif should_update:
tab_id = existing[name]
@ -197,19 +194,16 @@ class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors
""" Save the figure to file """
filename = self.name
now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = os.path.join(location,
"{}_{}.{}".format(filename,
now,
"png"))
filename = os.path.join(location, f"{filename}_{now}.png")
get_images().previewtrain[self.name][0].save(filename)
logger.debug("Saved preview to %s", filename)
print("Saved preview to {}".format(filename))
print(f"Saved preview to {filename}")
class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" The Graph Tab of the Display section """
def __init__(self, parent, tab_name, helptext, wait_time, command=None):
self._trace_vars = dict()
self._trace_vars = {}
super().__init__(parent, tab_name, helptext, wait_time, command)
def set_vars(self):
@ -370,6 +364,8 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
logger.trace("Loading graph")
self.display_item = Session
self._add_trace_variables()
elif Session.is_training and self.display_item is not None:
logger.trace("Graph already displayed. Nothing to do.")
else:
logger.trace("Clearing graph")
self.display_item = None
@ -384,9 +380,15 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
logger.debug("Adding graph")
existing = list(self.subnotebook_get_titles_ids().keys())
loss_keys = [key
for key in self.display_item.get_loss_keys(Session.session_ids[-1])
if key != "total"]
loss_keys = self.display_item.get_loss_keys(Session.session_ids[-1])
if not loss_keys:
# Reload if we attempt to get loss keys before data is written
logger.debug("Waiting for Session Data to become available to graph")
self.after(1000, self.display_item_process)
return
loss_keys = [key for key in loss_keys if key != "total"]
display_tabs = sorted(set(key[:-1].rstrip("_") for key in loss_keys))
for loss_key in display_tabs:
@ -472,7 +474,7 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
for name, (var, trace) in self._trace_vars.items():
logger.debug("Clearing trace from variable: %s", name)
var.trace_vdelete("w", trace)
self._trace_vars = dict()
self._trace_vars = {}
def close(self):
""" Clear the plots from RAM """

View File

@ -59,7 +59,7 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
@staticmethod
def set_vars():
""" Override to return a dict of page specific variables """
return dict()
return {}
def on_tab_select(self): # pylint:disable=no-self-use
""" Override for specific actions when the current tab is selected """
@ -151,7 +151,7 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
def subnotebook_get_titles_ids(self):
""" Return tabs ids and titles """
tabs = dict()
tabs = {}
for tab_id in range(0, self.subnotebook.index("end")):
tabs[self.subnotebook.tab(tab_id, "text")] = tab_id
logger.debug(tabs)
@ -213,11 +213,11 @@ class DisplayOptionalPage(DisplayPage): # pylint: disable=too-many-ancestors
def set_info_text(self):
""" Set waiting for display text """
if not self.vars["enabled"].get():
msg = "{} disabled".format(self.tabname.title())
msg = f"{self.tabname.title()} disabled"
elif self.vars["enabled"].get() and not self.vars["ready"].get():
msg = "Waiting for {}...".format(self.tabname)
msg = f"Waiting for {self.tabname}..."
else:
msg = "Displaying {}".format(self.tabname)
msg = f"Displaying {self.tabname}"
logger.debug(msg)
self.set_info(msg)
@ -235,7 +235,7 @@ class DisplayOptionalPage(DisplayPage): # pylint: disable=too-many-ancestors
command=self.save_items)
btnsave.pack(padx=2, side=tk.RIGHT)
Tooltip(btnsave,
text=_("Save {}(s) to file").format(self.tabname),
text=_(f"Save {self.tabname}(s) to file"),
wrap_length=200)
def add_option_enable(self):
@ -243,11 +243,11 @@ class DisplayOptionalPage(DisplayPage): # pylint: disable=too-many-ancestors
logger.debug("Adding enable option")
chkenable = ttk.Checkbutton(self.optsframe,
variable=self.vars["enabled"],
text="Enable {}".format(self.tabname),
text=f"Enable {self.tabname}",
command=self.on_chkenable_change)
chkenable.pack(side=tk.RIGHT, padx=5, anchor=tk.W)
Tooltip(chkenable,
text=_("Enable or disable {} display").format(self.tabname),
text=_(f"Enable or disable {self.tabname} display"),
wrap_length=200)
def save_items(self):

View File

@ -137,6 +137,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
"variable: %s)", self.__class__.__name__, handle_type, file_type, title,
initial_folder, initial_file, command, action, variable)
self._handletype = handle_type
self._dummy_master = self._set_dummy_master()
self._defaults = self._set_defaults()
self._kwargs = self._set_kwargs(title,
initial_folder,
@ -145,7 +146,9 @@ class FileHandler(): # pylint:disable=too-few-public-methods
command,
action,
variable)
self.return_file = getattr(self, "_{}".format(self._handletype.lower()))()
self.return_file = getattr(self, f"_{self._handletype.lower()}")()
self._remove_dummy_master()
logger.debug("Initialized %s", self.__class__.__name__)
@property
@ -184,10 +187,10 @@ class FileHandler(): # pylint:disable=too-few-public-methods
if platform.system() == "Linux":
filetypes[key] = [item
if item[0] == "All files"
else (item[0], "{} {}".format(item[1], item[1].upper()))
else (item[0], f"{item[1]} {item[1].upper()}")
for item in filetypes[key]]
if len(filetypes[key]) > 2:
multi = ["{} Files".format(key.title())]
multi = [f"{key.title()} Files"]
multi.append(" ".join([ftype[1]
for ftype in filetypes[key] if ftype[0] != "All files"]))
filetypes[key].insert(0, tuple(multi))
@ -214,6 +217,35 @@ class FileHandler(): # pylint:disable=too-few-public-methods
"rotate": "save_filename",
"slice": "save_filename"}))
@classmethod
def _set_dummy_master(cls):
""" Add an option to force black font on Linux file dialogs KDE issue that displays light
font on white background).
This is a pretty hacky solution, but tkinter does not allow direct editing of file dialogs,
so we create a dummy frame and add the foreground option there, so that the file dialog can
inherit the foreground.
Returns
-------
tkinter.Frame or ``None``
The dummy master frame for Linux systems, otherwise ``None``
"""
if platform.system().lower() == "linux":
retval = tk.Frame()
retval.option_add("*foreground", "black")
else:
retval = None
return retval
def _remove_dummy_master(self):
""" Destroy the dummy master widget on Linux systems. """
if platform.system().lower() != "linux":
return
self._dummy_master.destroy()
del self._dummy_master
self._dummy_master = None
def _set_defaults(self):
""" Set the default file type for the file dialog. Generally the first found file type
will be used, but this is overridden if it is not appropriate.
@ -264,7 +296,9 @@ class FileHandler(): # pylint:disable=too-few-public-methods
logger.debug("Setting Kwargs: (title: %s, initial_folder: %s, initial_file: '%s', "
"file_type: '%s', command: '%s': action: '%s', variable: '%s')",
title, initial_folder, initial_file, file_type, command, action, variable)
kwargs = dict()
kwargs = dict(master=self._dummy_master)
if self._handletype.lower() == "context":
self._set_context_handletype(command, action, variable)
@ -361,10 +395,10 @@ class Images():
self._pathpreview = os.path.join(PATHCACHE, "preview")
self._pathoutput = None
self._previewoutput = None
self._previewtrain = dict()
self._previewtrain = {}
self._previewcache = dict(modified=None, # cache for extract and convert
images=None,
filenames=list(),
filenames=[],
placeholder=None)
self._errcount = 0
self._icons = self._load_icons()
@ -420,7 +454,7 @@ class Images():
"""
size = get_config().user_config_dict.get("icon_size", 16)
size = int(round(size * get_config().scaling_factor))
icons = dict()
icons = {}
pathicons = os.path.join(PATHCACHE, "icons")
for fname in os.listdir(pathicons):
name, ext = os.path.splitext(fname)
@ -470,10 +504,10 @@ class Images():
logger.debug("Clearing image cache")
self._pathoutput = None
self._previewoutput = None
self._previewtrain = dict()
self._previewtrain = {}
self._previewcache = dict(modified=None, # cache for extract and convert
images=None,
filenames=list(),
filenames=[],
placeholder=None)
@staticmethod
@ -600,10 +634,10 @@ class Images():
logger.debug("num_images: %s", num_images)
if num_images == 0:
return False
samples = list()
samples = []
start_idx = len(image_files) - num_images if len(image_files) > num_images else 0
show_files = sorted(image_files, key=os.path.getctime)[start_idx:]
dropped_files = list()
dropped_files = []
for fname in show_files:
try:
img = Image.open(fname)
@ -732,7 +766,7 @@ class Images():
modified = None
if not image_files:
logger.debug("No preview to display")
self._previewtrain = dict()
self._previewtrain = {}
return
for img in image_files:
modified = os.path.getmtime(img) if modified is None else modified
@ -755,7 +789,7 @@ class Images():
self._errcount += 1
else:
logger.error("Error reading the preview file for '%s'", img)
print("Error reading the preview file for {}".format(name))
print(f"Error reading the preview file for {name}")
self._previewtrain[name] = None
def _get_current_size(self, name):
@ -1126,7 +1160,7 @@ class Config():
Additional text to be appended to the GUI title bar. Default: ``None``
"""
title = "Faceswap.py"
title += " - {}".format(text) if text is not None and text else ""
title += f" - {text}" if text is not None and text else ""
self.root.title(title)
def set_geometry(self, width, height, fullscreen=False):
@ -1154,8 +1188,7 @@ class Config():
elif fullscreen:
self.root.attributes('-zoomed', True)
else:
self.root.geometry("{}x{}+80+80".format(str(initial_dimensions[0]),
str(initial_dimensions[1])))
self.root.geometry(f"{str(initial_dimensions[0])}x{str(initial_dimensions[1])}+80+80")
logger.debug("Geometry: %sx%s", *initial_dimensions)
@ -1260,7 +1293,7 @@ class PreviewTrigger():
"""
trigger = self._trigger_files[trigger_type]
if not os.path.isfile(trigger):
with open(trigger, "w"):
with open(trigger, "w", encoding="utf8"):
pass
logger.debug("Set preview trigger: %s", trigger)

View File

@ -9,9 +9,8 @@ import numpy as np
import tensorflow as tf
from keras import backend as K
from keras import initializers
from keras.utils import get_custom_objects
from lib.utils import get_backend
from lib.utils import get_backend, get_keras_custom_objects as get_custom_objects
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -64,7 +63,7 @@ def compute_fans(shape, data_format='channels_last'):
return fan_in, fan_out
class ICNR(initializers.Initializer): # pylint: disable=invalid-name
class ICNR(initializers.Initializer): # pylint: disable=invalid-name,no-member
""" ICNR initializer for checkerboard artifact free sub pixel convolution
Parameters
@ -167,11 +166,11 @@ class ICNR(initializers.Initializer): # pylint: disable=invalid-name
config = {"scale": self.scale,
"initializer": self.initializer
}
base_config = super(ICNR, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
class ConvolutionAware(initializers.Initializer):
class ConvolutionAware(initializers.Initializer): # pylint: disable=no-member
"""
Initializer that generates orthogonal convolution filters in the Fourier space. If this
initializer is passed a shape that is not 3D or 4D, orthogonal initialization will be used.
@ -204,8 +203,8 @@ class ConvolutionAware(initializers.Initializer):
def __init__(self, eps_std=0.05, seed=None, initialized=False):
self.eps_std = eps_std
self.seed = seed
self.orthogonal = initializers.Orthogonal()
self.he_uniform = initializers.he_uniform()
self.orthogonal = initializers.Orthogonal() # pylint:disable=no-member
self.he_uniform = initializers.he_uniform() # pylint:disable=no-member
self.initialized = initialized
def __call__(self, shape, dtype=None):

View File

@ -10,16 +10,15 @@ import tensorflow as tf
import keras.backend as K
from keras.layers import InputSpec, Layer
from keras.utils import get_custom_objects
from lib.utils import get_backend
from lib.utils import get_backend, get_keras_custom_objects as get_custom_objects
if get_backend() == "amd":
from lib.plaidml_utils import pad
from keras.utils import conv_utils # pylint:disable=ungrouped-imports
else:
from tensorflow import pad
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import conv_utils # pylint:disable=no-name-in-module
class PixelShuffler(Layer):
@ -64,20 +63,22 @@ class PixelShuffler(Layer):
def __init__(self, size=(2, 2), data_format=None, **kwargs):
super().__init__(**kwargs)
if get_backend() == "amd":
self.data_format = K.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format) # pylint:disable=no-member
else:
self.data_format = conv_utils.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 2, 'size')
def call(self, inputs, **kwargs): # pylint:disable=unused-argument
def call(self, inputs, *args, **kwargs):
"""This is where the layer's logic lives.
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
args: tuple
Additional standard keras Layer arguments
kwargs: dict
Additional keyword arguments. Unused
Additional standard keras Layer keyword arguments
Returns
-------
@ -186,7 +187,7 @@ class PixelShuffler(Layer):
"""
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(PixelShuffler, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@ -208,15 +209,17 @@ class KResizeImages(Layer):
self.size = size
self.interpolation = interpolation
def call(self, inputs, **kwargs): # pylint:disable=unused-argument
def call(self, inputs, *args, **kwargs):
""" Call the upsample layer
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
args: tuple
Additional standard keras Layer arguments
kwargs: dict
Additional keyword arguments. Unused
Additional standard keras Layer keyword arguments
Returns
-------
@ -316,11 +319,11 @@ class SubPixelUpscaling(Layer):
"""
def __init__(self, scale_factor=2, data_format=None, **kwargs):
super(SubPixelUpscaling, self).__init__(**kwargs)
super().__init__(**kwargs)
self.scale_factor = scale_factor
if get_backend() == "amd":
self.data_format = K.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format) # pylint:disable=no-member
else:
self.data_format = conv_utils.normalize_data_format(data_format)
@ -337,15 +340,17 @@ class SubPixelUpscaling(Layer):
"""
pass # pylint: disable=unnecessary-pass
def call(self, inputs, **kwargs): # pylint:disable=unused-argument
def call(self, inputs, *args, **kwargs):
"""This is where the layer's logic lives.
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
args: tuple
Additional standard keras Layer arguments
kwargs: dict
Additional keyword arguments. Unused
Additional standard keras Layer keyword arguments
Returns
-------
@ -460,7 +465,7 @@ class SubPixelUpscaling(Layer):
"""
config = {"scale_factor": self.scale_factor,
"data_format": self.data_format}
base_config = super(SubPixelUpscaling, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@ -592,7 +597,7 @@ class ReflectionPadding2D(Layer):
"""
config = {'stride': self.stride,
'kernel_size': self.kernel_size}
base_config = super(ReflectionPadding2D, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@ -602,9 +607,9 @@ class _GlobalPooling2D(Layer):
From keras as access to pooling is trickier in tensorflow.keras
"""
def __init__(self, data_format=None, **kwargs):
super(_GlobalPooling2D, self).__init__(**kwargs)
super().__init__(**kwargs)
if get_backend() == "amd":
self.data_format = K.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format) # pylint:disable=no-member
else:
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
@ -621,37 +626,41 @@ class _GlobalPooling2D(Layer):
return (input_shape[0], input_shape[3])
return (input_shape[0], input_shape[1])
def call(self, inputs, **kwargs):
def call(self, inputs, *args, **kwargs):
""" Override to call the layer.
Parameters
----------
inputs: Tensor
The input to the layer
args: tuple
Additional standard keras Layer arguments
kwargs: dict
Additional keyword arguments
Additional standard keras Layer keyword arguments
"""
raise NotImplementedError
def get_config(self):
""" Set the Keras config """
config = {'data_format': self.data_format}
base_config = super(_GlobalPooling2D, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
class GlobalMinPooling2D(_GlobalPooling2D):
"""Global minimum pooling operation for spatial data. """
def call(self, inputs, **kwargs):
def call(self, inputs, *args, **kwargs):
"""This is where the layer's logic lives.
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
args: tuple
Additional standard keras Layer arguments
kwargs: dict
Additional keyword arguments
Additional standard keras Layer keyword arguments
Returns
-------
@ -668,15 +677,17 @@ class GlobalMinPooling2D(_GlobalPooling2D):
class GlobalStdDevPooling2D(_GlobalPooling2D):
"""Global standard deviation pooling operation for spatial data. """
def call(self, inputs, **kwargs):
def call(self, inputs, *args, **kwargs):
"""This is where the layer's logic lives.
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
args: tuple
Additional standard keras Layer arguments
kwargs: dict
Additional keyword arguments
Additional standard keras Layer keyword arguments
Returns
-------
@ -702,7 +713,7 @@ class L2_normalize(Layer): # pylint:disable=invalid-name
"""
def __init__(self, axis, **kwargs):
self.axis = axis
super(L2_normalize, self).__init__(**kwargs)
super().__init__(**kwargs)
def call(self, inputs): # pylint:disable=arguments-differ
"""This is where the layer's logic lives.
@ -736,7 +747,7 @@ class L2_normalize(Layer): # pylint:disable=invalid-name
dict
A python dictionary containing the layer configuration
"""
config = super(L2_normalize, self).get_config()
config = super().get_config()
config["axis"] = self.axis
return config

View File

@ -7,17 +7,17 @@ import logging
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import compile_utils
from tensorflow.python.keras.engine import compile_utils # pylint:disable=no-name-in-module
from keras import backend as K
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
class DSSIMObjective(tf.keras.losses.Loss):
class DSSIMObjective(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
""" DSSIM Loss Function
Difference of Structural Similarity (DSSIM loss function). Clipped between 0 and 0.5
Difference of Structural Similarity (DSSIM loss function).
Parameters
----------
@ -25,66 +25,24 @@ class DSSIMObjective(tf.keras.losses.Loss):
Parameter of the SSIM. Default: `0.01`
k_2: float, optional
Parameter of the SSIM. Default: `0.03`
kernel_size: int, optional
Size of the sliding window Default: `3`
filter_size: int, optional
size of gaussian filter Default: `11`
filter_sigma: float, optional
Width of gaussian filter Default: `1.5`
max_value: float, optional
Max value of the output. Default: `1.0`
Notes
------
You should add a regularization term like a l2 loss in addition to this one.
References
----------
https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
MIT License
Copyright (c) 2017 Fariz Rahman
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
def __init__(self, k_1=0.01, k_2=0.03, kernel_size=3, max_value=1.0):
def __init__(self, k_1=0.01, k_2=0.03, filter_size=11, filter_sigma=1.5, max_value=1.0):
super().__init__(name="DSSIMObjective")
self.kernel_size = kernel_size
self.filter_size = filter_size
self.filter_sigma = filter_sigma
self.k_1 = k_1
self.k_2 = k_2
self.max_value = max_value
self.c_1 = (self.k_1 * self.max_value) ** 2
self.c_2 = (self.k_2 * self.max_value) ** 2
self.dim_ordering = K.image_data_format()
@staticmethod
def _int_shape(input_tensor):
""" Returns the shape of tensor or variable as a tuple of int or None entries.
Parameters
----------
input_tensor: tensor or variable
The input to return the shape for
Returns
-------
tuple
A tuple of integers (or None entries)
"""
return K.int_shape(input_tensor)
def call(self, y_true, y_pred):
""" Call the DSSIM Loss Function.
@ -100,104 +58,19 @@ class DSSIMObjective(tf.keras.losses.Loss):
-------
tensor
The DSSIM Loss value
Notes
-----
There are additional parameters for this function. some of the 'modes' for edge behavior
do not yet have a gradient definition in the Theano tree and cannot be used for learning
"""
kernel = [self.kernel_size, self.kernel_size]
y_true = K.reshape(y_true, [-1] + list(self._int_shape(y_pred)[1:]))
y_pred = K.reshape(y_pred, [-1] + list(self._int_shape(y_pred)[1:]))
patches_pred = self.extract_image_patches(y_pred,
kernel,
kernel,
'valid',
self.dim_ordering)
patches_true = self.extract_image_patches(y_true,
kernel,
kernel,
'valid',
self.dim_ordering)
# Get mean
u_true = K.mean(patches_true, axis=-1)
u_pred = K.mean(patches_pred, axis=-1)
# Get variance
var_true = K.var(patches_true, axis=-1)
var_pred = K.var(patches_pred, axis=-1)
# Get standard deviation
covar_true_pred = K.mean(
patches_true * patches_pred, axis=-1) - u_true * u_pred
ssim = (2 * u_true * u_pred + self.c_1) * (
2 * covar_true_pred + self.c_2)
denom = (K.square(u_true) + K.square(u_pred) + self.c_1) * (
var_pred + var_true + self.c_2)
ssim /= denom # no need for clipping, c_1 + c_2 make the denorm non-zero
return (1.0 - ssim) / 2.0
@staticmethod
def _preprocess_padding(padding):
"""Convert keras padding to tensorflow padding.
Parameters
----------
padding: string,
`"same"` or `"valid"`.
Returns
-------
str
`"SAME"` or `"VALID"`.
Raises
------
ValueError
If `padding` is invalid.
"""
if padding == 'same':
padding = 'SAME'
elif padding == 'valid':
padding = 'VALID'
else:
raise ValueError('Invalid padding:', padding)
return padding
def extract_image_patches(self, input_tensor, k_sizes, s_sizes,
padding='same', data_format='channels_last'):
""" Extract the patches from an image.
Parameters
----------
input_tensor: tensor
The input image
k_sizes: tuple
2-d tuple with the kernel size
s_sizes: tuple
2-d tuple with the strides size
padding: str, optional
`"same"` or `"valid"`. Default: `"same"`
data_format: str, optional.
`"channels_last"` or `"channels_first"`. Default: `"channels_last"`
Returns
-------
The (k_w, k_h) patches extracted
Tensorflow ==> (batch_size, w, h, k_w, k_h, c)
Theano ==> (batch_size, w, h, c, k_w, k_h)
"""
kernel = [1, k_sizes[0], k_sizes[1], 1]
strides = [1, s_sizes[0], s_sizes[1], 1]
padding = self._preprocess_padding(padding)
if data_format == 'channels_first':
input_tensor = K.permute_dimensions(input_tensor, (0, 2, 3, 1))
patches = tf.image.extract_patches(input_tensor, kernel, strides, [1, 1, 1, 1], padding)
return patches
ssim = tf.image.ssim(y_true,
y_pred,
self.max_value,
filter_size=self.filter_size,
filter_sigma=self.filter_sigma,
k1=self.k_1,
k2=self.k_2)
dssim_loss = 1. - ssim
return dssim_loss
class GeneralizedLoss(tf.keras.losses.Loss):
class GeneralizedLoss(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
""" Generalized function used to return a large variety of mathematical loss functions.
The primary benefit is a smooth, differentiable version of L1 loss.
@ -247,10 +120,11 @@ class GeneralizedLoss(tf.keras.losses.Loss):
return loss
class LInfNorm(tf.keras.losses.Loss):
class LInfNorm(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
""" Calculate the L-inf norm as a loss function. """
def call(self, y_true, y_pred):
@classmethod
def call(cls, y_true, y_pred):
""" Call the L-inf norm loss function.
Parameters
@ -271,7 +145,7 @@ class LInfNorm(tf.keras.losses.Loss):
return loss
class GradientLoss(tf.keras.losses.Loss):
class GradientLoss(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
""" Gradient Loss Function.
Calculates the first and second order gradient difference between pixels of an image in the x
@ -392,7 +266,7 @@ class GradientLoss(tf.keras.losses.Loss):
return (xy_out1 - xy_out2) * 0.25
class GMSDLoss(tf.keras.losses.Loss):
class GMSDLoss(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
""" Gradient Magnitude Similarity Deviation Loss.
Improved image quality metric over MS-SSIM with easier calculations
@ -486,7 +360,9 @@ class GMSDLoss(tf.keras.losses.Loss):
# Use depth-wise convolution to calculate edge maps per channel.
# Output tensor has shape [batch_size, h, w, d * num_kernels].
pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
padded = tf.pad(image, pad_sizes, mode='REFLECT')
padded = tf.pad(image, # pylint:disable=unexpected-keyword-arg,no-value-for-parameter
pad_sizes,
mode='REFLECT')
output = K.depthwise_conv2d(padded, kernels)
if not magnitude: # direction of edges

View File

@ -7,15 +7,16 @@ import inspect
from keras.layers import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils import get_custom_objects
from lib.utils import get_backend
from lib.utils import get_backend, get_keras_custom_objects as get_custom_objects
if get_backend() == "amd":
from keras.backend import normalize_data_format # pylint:disable=ungrouped-imports
from keras.backend \
import normalize_data_format # pylint:disable=ungrouped-imports,no-name-in-module
else:
from tensorflow.python.keras.utils.conv_utils import normalize_data_format
# pylint:disable=no-name-in-module
from tensorflow.python.keras.utils.conv_utils \
import normalize_data_format # pylint:disable=no-name-in-module
class InstanceNormalization(Layer):
@ -61,6 +62,7 @@ class InstanceNormalization(Layer):
- Instance Normalization: The Missing Ingredient for Fast Stylization - \
https://arxiv.org/abs/1607.08022
"""
# pylint:disable=too-many-instance-attributes,too-many-arguments
def __init__(self,
axis=None,
epsilon=1e-3,
@ -348,6 +350,7 @@ class GroupNormalization(Layer):
----------
Shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN
"""
# pylint:disable=too-many-instance-attributes
def __init__(self, axis=-1, gamma_init='one', beta_init='zero', gamma_regularizer=None,
beta_regularizer=None, epsilon=1e-6, group=32, data_format=None, **kwargs):
self.beta = None

View File

@ -4,10 +4,14 @@ import inspect
import sys
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.backend as K # pylint:disable=no-name-in-module,import-error
# tf.keras has a LayerNormaliztion implementation
from tensorflow.keras.layers import Layer, LayerNormalization # noqa pylint:disable=unused-import
from tensorflow.keras.utils import get_custom_objects
# pylint:disable=unused-import
from tensorflow.keras.layers import ( # noqa pylint:disable=no-name-in-module,import-error
Layer,
LayerNormalization)
from lib.utils import get_keras_custom_objects as get_custom_objects
class RMSNormalization(Layer):
@ -117,9 +121,10 @@ class RMSNormalization(Layer):
mean_square = K.mean(K.square(inputs), axis=self.axis, keepdims=True)
else:
partial_size = int(layer_size * self.partial)
partial_x, _ = tf.split(inputs,
[partial_size, layer_size - partial_size],
axis=self.axis)
partial_x, _ = tf.split( # pylint:disable=redundant-keyword-arg,no-value-for-parameter
inputs,
[partial_size, layer_size - partial_size],
axis=self.axis)
mean_square = K.mean(K.square(partial_x), axis=self.axis, keepdims=True)
recip_square_root = tf.math.rsqrt(mean_square + self.epsilon)

View File

@ -4,7 +4,7 @@ import inspect
import sys
from keras import backend as K
from keras.optimizers import Optimizer
from keras.optimizers import Optimizer, Adam, Nadam, RMSprop # noqa pylint:disable=unused-import
from keras.utils import get_custom_objects

View File

@ -8,7 +8,10 @@ import inspect
import sys
import tensorflow as tf
from keras.utils import get_custom_objects
from tensorflow.keras.optimizers import ( # noqa pylint:disable=no-name-in-module,unused-import,import-error
Adam, Nadam, RMSprop)
from lib.utils import get_keras_custom_objects as get_custom_objects
class AdaBelief(tf.keras.optimizers.Optimizer):
@ -128,6 +131,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-14,
weight_decay=0.0, rectify=True, amsgrad=False, sma_threshold=5.0, total_steps=0,
warmup_proportion=0.1, min_lr=0.0, name="AdaBeliefOptimizer", **kwargs):
# pylint:disable=too-many-arguments
super().__init__(name, **kwargs)
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
self._set_hyper("beta_1", beta_1)
@ -196,7 +200,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
return wd_t
def _resource_apply_dense(self, grad, handle, apply_state=None):
# pylint:disable=too-many-locals
# pylint:disable=too-many-locals,unused-argument
""" Add ops to apply dense gradients to the variable handle.
Parameters
@ -274,7 +278,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
return tf.group(*updates)
def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
# pylint:disable=too-many-locals
# pylint:disable=too-many-locals, unused-argument
""" Add ops to apply sparse gradients to the variable handle.
Similar to _apply_sparse, the indices argument to this method has been de-duplicated.
@ -324,7 +328,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
m_corr_t = m_t / (1.0 - beta_1_power)
var_v = self.get_slot(handle, "v")
m_t_indices = tf.gather(m_t, indices)
m_t_indices = tf.gather(m_t, indices) # pylint:disable=no-value-for-parameter
v_scaled_g_values = tf.math.square(grad - m_t_indices) * (1 - beta_2_t)
v_t = var_v.assign(var_v * beta_2_t + epsilon_t, use_locking=self._use_locking)
v_t = self._resource_scatter_add(var_v, indices, v_scaled_g_values)
@ -355,7 +359,9 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
var_update = self._resource_scatter_add(handle,
indices,
tf.gather(tf.math.negative(lr_t) * var_t, indices))
tf.gather( # pylint:disable=no-value-for-parameter
tf.math.negative(lr_t) * var_t,
indices))
updates = [var_update, m_t, v_t]
if self.amsgrad:
@ -391,6 +397,6 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
# Update layers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
for _name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and obj.__module__ == __name__:
get_custom_objects().update({name: obj})
get_custom_objects().update({_name: obj})

View File

@ -22,6 +22,7 @@ _image_extensions = [ # pylint:disable=invalid-name
_video_extensions = [ # pylint:disable=invalid-name
".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
".ts", ".vob"]
_TF_VERS = None
class _Backend(): # pylint:disable=too-few-public-methods
@ -60,8 +61,7 @@ class _Backend(): # pylint:disable=too-few-public-methods
# Check if environment variable is set, if so use that
if "FACESWAP_BACKEND" in os.environ:
fs_backend = os.environ["FACESWAP_BACKEND"].lower()
print("Setting Faceswap backend from environment variable to "
"{}".format(fs_backend.upper()))
print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
return fs_backend
# Intercept for sphinx docs build
if sys.argv[0].endswith("sphinx-build"):
@ -70,7 +70,7 @@ class _Backend(): # pylint:disable=too-few-public-methods
self._configure_backend()
while True:
try:
with open(self._config_file, "r") as cnf:
with open(self._config_file, "r", encoding="utf8") as cnf:
config = json.load(cnf)
break
except json.decoder.JSONDecodeError:
@ -80,7 +80,7 @@ class _Backend(): # pylint:disable=too-few-public-methods
if fs_backend is None or fs_backend.lower() not in self._backends.values():
fs_backend = self._configure_backend()
if current_process().name == "MainProcess":
print("Setting Faceswap backend to {}".format(fs_backend.upper()))
print(f"Setting Faceswap backend to {fs_backend.upper()}")
return fs_backend.lower()
def _configure_backend(self):
@ -95,14 +95,14 @@ class _Backend(): # pylint:disable=too-few-public-methods
while True:
selection = input("1: AMD, 2: CPU, 3: NVIDIA: ")
if selection not in ("1", "2", "3"):
print("'{}' is not a valid selection. Please try again".format(selection))
print(f"'{selection}' is not a valid selection. Please try again")
continue
break
fs_backend = self._backends[selection].lower()
config = {"backend": fs_backend}
with open(self._config_file, "w") as cnf:
with open(self._config_file, "w", encoding="utf8") as cnf:
json.dump(config, cnf)
print("Faceswap config written to: {}".format(self._config_file))
print(f"Faceswap config written to: {self._config_file}")
return fs_backend
@ -132,6 +132,32 @@ def set_backend(backend):
_FS_BACKEND = backend.lower()
def get_tf_version():
""" Obtain the major.minor version of currently installed Tensorflow.
Returns
-------
float
The currently installed tensorflow version
"""
global _TF_VERS # pylint:disable=global-statement
if _TF_VERS is None:
import tensorflow as tf # pylint:disable=import-outside-toplevel
_TF_VERS = float(".".join(tf.__version__.split(".")[:2])) # pylint:disable=no-member
return _TF_VERS
def get_keras_custom_objects():
""" Wrapper to obtain keras.utils.get_custom_objects from correct location depending on
backend used and tensorflow version. """
# pylint:disable=no-name-in-module,import-outside-toplevel
if get_backend() == "amd" or get_tf_version() < 2.8:
from keras.utils import get_custom_objects
else:
from keras.utils.generic_utils import get_custom_objects
return get_custom_objects()
def get_folder(path, make_folder=True):
""" Return a path to a folder, creating it if it doesn't exist
@ -176,7 +202,7 @@ def get_image_paths(directory, extension=None):
"""
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
image_extensions = _image_extensions if extension is None else [extension]
dir_contents = list()
dir_contents = []
if not os.path.exists(directory):
logger.debug("Creating folder: '%s'", directory)
@ -242,7 +268,7 @@ def full_path_split(path):
>>> ["foo", "baz", "bar"]
"""
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
allparts = list()
allparts = []
while True:
parts = os.path.split(path)
if parts[0] == path: # sentinel for absolute paths
@ -297,9 +323,9 @@ def deprecation_warning(function, additional_info=None):
"""
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
logger.debug("func_name: %s, additional_info: %s", function, additional_info)
msg = "{} has been deprecated and will be removed from a future update.".format(function)
msg = f"{function} has been deprecated and will be removed from a future update."
if additional_info is not None:
msg += " {}".format(additional_info)
msg += f" {additional_info}"
logger.warning(msg)
@ -355,7 +381,7 @@ class FaceswapError(Exception):
pass # pylint:disable=unnecessary-pass
class GetModel(): # Pylint:disable=too-few-public-methods
class GetModel(): # pylint:disable=too-few-public-methods
""" Check for models in their cache path.
If available, return the path, if not available, get, unzip and install model
@ -428,7 +454,7 @@ class GetModel(): # Pylint:disable=too-few-public-methods
@property
def _model_zip_path(self):
""" str: The full path to downloaded zip file. """
retval = os.path.join(self._cache_dir, "{}.zip".format(self._model_full_name))
retval = os.path.join(self._cache_dir, f"{self._model_full_name}.zip")
self.logger.trace(retval)
return retval
@ -462,8 +488,8 @@ class GetModel(): # Pylint:disable=too-few-public-methods
@property
def _url_download(self):
""" strL Base download URL for models. """
tag = "v{}.{}.{}".format(self._url_section, self._git_model_id, self._model_version)
retval = "{}/{}/{}.zip".format(self._url_base, tag, self._model_full_name)
tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}"
retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
self.logger.trace("Download url: %s", retval)
return retval
@ -493,11 +519,11 @@ class GetModel(): # Pylint:disable=too-few-public-methods
downloaded_size = self._url_partial_size
req = urllib.request.Request(self._url_download)
if downloaded_size != 0:
req.add_header("Range", "bytes={}-".format(downloaded_size))
response = urllib.request.urlopen(req, timeout=10)
self.logger.debug("header info: {%s}", response.info())
self.logger.debug("Return Code: %s", response.getcode())
self._write_zipfile(response, downloaded_size)
req.add_header("Range", f"bytes={downloaded_size}-")
with urllib.request.urlopen(req, timeout=10) as response:
self.logger.debug("header info: {%s}", response.info())
self.logger.debug("Return Code: %s", response.getcode())
self._write_zipfile(response, downloaded_size)
break
except (socket_error, socket_timeout,
urllib.error.HTTPError, urllib.error.URLError) as err:
@ -548,8 +574,8 @@ class GetModel(): # Pylint:disable=too-few-public-methods
""" Unzip the model file to the cache folder """
self.logger.info("Extracting: '%s'", self._model_name)
try:
zip_file = zipfile.ZipFile(self._model_zip_path, "r")
self._write_model(zip_file)
with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
self._write_model(zip_file)
except Exception as err: # pylint:disable=broad-except
self.logger.error("Unable to extract model file: %s", str(err))
sys.exit(1)

View File

@ -20,13 +20,12 @@ from keras import losses as k_losses
from keras import backend as K
from keras.layers import Input
from keras.models import load_model, Model as KModel
from keras.optimizers import Adam, Nadam, RMSprop
from lib.serializer import get_serializer
from lib.model.backup_restore import Backup
from lib.model import losses, optimizers
from lib.model.nn_blocks import set_config as set_nnblock_config
from lib.utils import get_backend, FaceswapError
from lib.utils import get_backend, get_tf_version, FaceswapError
from plugins.train._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -264,13 +263,14 @@ class ModelBase():
return
if len(multiple_models) == 1:
msg = ("You have requested to train with the '{}' plugin, but a model file for the "
"'{}' plugin already exists in the folder '{}'.\nPlease select a different "
"model folder.".format(self.name, multiple_models[0], self.model_dir))
msg = (f"You have requested to train with the '{self.name}' plugin, but a model file "
f"for the '{multiple_models[0]}' plugin already exists in the folder "
f"'{self.model_dir}'.\nPlease select a different model folder.")
else:
msg = ("There are multiple plugin types ('{}') stored in the model folder '{}'. This "
"is not supported.\nPlease split the model files into their own folders before "
"proceeding".format("', '".join(multiple_models), self.model_dir))
ptypes = "', '".join(multiple_models)
msg = (f"There are multiple plugin types ('{ptypes}') stored in the model folder '"
f"{self.model_dir}'. This is not supported.\nPlease split the model files into "
"their own folders before proceeding")
raise FaceswapError(msg)
def build(self):
@ -311,11 +311,11 @@ class ModelBase():
if not all(os.path.isfile(os.path.join(self.model_dir, fname))
for fname in self._legacy_mapping()):
return
archive_dir = "{}_TF1_Archived".format(self.model_dir)
archive_dir = f"{self.model_dir}_TF1_Archived"
if os.path.exists(archive_dir):
raise FaceswapError("We need to update your model files for use with Tensorflow 2.x, "
"but the archive folder already exists. Please remove the "
"following folder to continue: '{}'".format(archive_dir))
f"following folder to continue: '{archive_dir}'")
logger.info("Updating legacy models for Tensorflow 2.x")
logger.info("Your Tensorflow 1.x models will be archived in the following location: '%s'",
@ -371,7 +371,7 @@ class ModelBase():
input_shapes = [self.input_shape, self.input_shape]
else:
input_shapes = self.input_shape
inputs = [Input(shape=shape, name="face_in_{}".format(side))
inputs = [Input(shape=shape, name=f"face_in_{side}")
for side, shape in zip(("a", "b"), input_shapes)]
logger.debug("inputs: %s", inputs)
return inputs
@ -450,7 +450,7 @@ class ModelBase():
seen = {name: 0 for name in set(self._model.output_names)}
new_names = []
for name in self._model.output_names:
new_names.append("{}_{}".format(name, seen[name]))
new_names.append(f"{name}_{seen[name]}")
seen[name] += 1
logger.debug("Output names rewritten: (old: %s, new: %s)",
self._model.output_names, new_names)
@ -510,7 +510,7 @@ class _IO():
@property
def _filename(self):
"""str: The filename for this model."""
return os.path.join(self._model_dir, "{}.h5".format(self._plugin.name))
return os.path.join(self._model_dir, f"{self._plugin.name}.h5")
@property
def model_exists(self):
@ -567,7 +567,7 @@ class _IO():
"You can try to load the model again but if the problem persists you "
"should use the Restore Tool to restore your model from backup.\n"
f"Original error: {str(err)}")
raise FaceswapError(msg)
raise FaceswapError(msg) from err
raise err
except KeyError as err:
if "unable to open object" in str(err).lower():
@ -576,7 +576,7 @@ class _IO():
"You can try to load the model again but if the problem persists you "
"should use the Restore Tool to restore your model from backup.\n"
f"Original error: {str(err)}")
raise FaceswapError(msg)
raise FaceswapError(msg) from err
raise err
logger.info("Loaded model from disk: '%s'", self._filename)
@ -605,9 +605,9 @@ class _IO():
msg = "[Saved models]"
if save_averages:
lossmsg = ["face_{}: {:.5f}".format(side, avg)
lossmsg = [f"face_{side}: {avg:.5f}"
for side, avg in zip(("a", "b"), save_averages)]
msg += " - Average loss since last save: {}".format(", ".join(lossmsg))
msg += f" - Average loss since last save: {', '.join(lossmsg)}"
logger.info(msg)
def _get_save_averages(self):
@ -699,12 +699,11 @@ class _Settings():
logger.debug("Initializing %s: (arguments: %s, mixed_precision: %s, allow_growth: %s, "
"is_predict: %s)", self.__class__.__name__, arguments, mixed_precision,
allow_growth, is_predict)
self._tf_version = [int(i) for i in tf.__version__.split(".")[:2]]
self._set_tf_settings(allow_growth, arguments.exclude_gpus)
use_mixed_precision = not is_predict and mixed_precision and get_backend() == "nvidia"
# Mixed precision moved out of experimental in tensorflow 2.4
if use_mixed_precision and self._tf_version[0] == 2 and self._tf_version[1] < 4:
if use_mixed_precision and get_tf_version() < 2.4:
self._mixed_precision = tf.keras.mixed_precision.experimental
elif use_mixed_precision:
self._mixed_precision = tf.keras.mixed_precision
@ -743,9 +742,8 @@ class _Settings():
"""
# tensorflow versions < 2.4 had different kwargs where scaling needs to be explicitly
# defined
vers = self._tf_version
kwargs = dict(loss_scale="dynamic") if vers[0] == 2 and vers[1] < 4 else dict()
logger.debug("tf version: %s, kwargs: %s", vers, kwargs)
kwargs = dict(loss_scale="dynamic") if get_tf_version() < 2.4 else {}
logger.debug("tf version: %s, kwargs: %s", get_tf_version(), kwargs)
return self._mixed_precision.LossScaleOptimizer(optimizer, **kwargs)
@classmethod
@ -821,10 +819,11 @@ class _Settings():
return False
logger.info("Enabling Mixed Precision Training.")
if exclude_gpus and self._tf_version[0] == 2 and self._tf_version[1] == 2:
if exclude_gpus and get_tf_version() == 2.2:
# TODO remove this hacky fix to disable mixed precision compatibility testing when
# tensorflow 2.2 support dropped
# pylint:disable=import-outside-toplevel,protected-access,import-error
# pylint:disable=import-outside-toplevel,protected-access
# pylint:disable=import-error,no-name-in-module
from tensorflow.python.keras.mixed_precision.experimental import \
device_compatibility_check
logger.debug("Overriding tensorflow _logged_compatibility_check parameter. Initial "
@ -833,7 +832,7 @@ class _Settings():
logger.debug("New value: %s", device_compatibility_check._logged_compatibility_check)
policy = self._mixed_precision.Policy('mixed_float16')
if self._tf_version[0] == 2 and self._tf_version[1] < 4:
if get_tf_version() < 2.4:
self._mixed_precision.set_policy(policy)
else:
self._mixed_precision.set_global_policy(policy)
@ -1102,9 +1101,11 @@ class _Optimizer(): # pylint:disable=too-few-public-methods
optimizer, learning_rate, clipnorm, epsilon, arguments)
valid_optimizers = {"adabelief": (optimizers.AdaBelief,
dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"adam": (Adam, dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"nadam": (Nadam, dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"rms-prop": (RMSprop, dict(epsilon=epsilon))}
"adam": (optimizers.Adam,
dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"nadam": (optimizers.Nadam,
dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"rms-prop": (optimizers.RMSprop, dict(epsilon=epsilon))}
self._optimizer, self._kwargs = valid_optimizers[optimizer]
self._configure(learning_rate, clipnorm, arguments)
@ -1173,7 +1174,7 @@ class _Loss():
self._uses_l2_reg = ["ssim", "gmsd"]
self._inputs = None
self._names = []
self._funcs = dict()
self._funcs = {}
logger.debug("Initialized: %s", self.__class__.__name__)
@property
@ -1248,7 +1249,7 @@ class _Loss():
side, output_names, output_shapes, output_types)
self._names.extend(["{}_{}{}".format(name, side,
"" if output_types.count(name) == 1
else "_{}".format(idx))
else f"_{idx}")
for idx, name in enumerate(output_types)])
logger.debug(self._names)
@ -1354,13 +1355,13 @@ class State():
"config_changeable_items: '%s', no_logs: %s", self.__class__.__name__,
model_dir, model_name, config_changeable_items, no_logs)
self._serializer = get_serializer("json")
filename = "{}_state.{}".format(model_name, self._serializer.file_extension)
filename = f"{model_name}_state.{self._serializer.file_extension}"
self._filename = os.path.join(model_dir, filename)
self._name = model_name
self._iterations = 0
self._sessions = dict()
self._lowest_avg_loss = dict()
self._config = dict()
self._sessions = {}
self._lowest_avg_loss = {}
self._config = {}
self._load(config_changeable_items)
self._session_id = self._new_session_id()
self._create_new_session(no_logs, config_changeable_items)
@ -1473,10 +1474,10 @@ class State():
return
state = self._serializer.load(self._filename)
self._name = state.get("name", self._name)
self._sessions = state.get("sessions", dict())
self._lowest_avg_loss = state.get("lowest_avg_loss", dict())
self._sessions = state.get("sessions", {})
self._lowest_avg_loss = state.get("lowest_avg_loss", {})
self._iterations = state.get("iterations", 0)
self._config = state.get("config", dict())
self._config = state.get("config", {})
logger.debug("Loaded state: %s", state)
self._replace_config(config_changeable_items)
@ -1670,7 +1671,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
logger.debug("Compiling inference model. saved_model: %s", saved_model)
struct = self._get_filtered_structure()
model_inputs = self._get_inputs(saved_model.inputs)
compiled_layers = dict()
compiled_layers = {}
for layer in saved_model.layers:
if layer.name not in struct:
logger.debug("Skipping unused layer: '%s'", layer.name)
@ -1704,7 +1705,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
logger.debug("Compiling layer '%s': layer inputs: %s", layer.name, layer_inputs)
model = layer(layer_inputs)
compiled_layers[layer.name] = model
retval = KerasModel(model_inputs, model, name="{}_inference".format(saved_model.name))
retval = KerasModel(model_inputs, model, name=f"{saved_model.name}_inference")
logger.debug("Compiled inference model '%s': %s", retval.name, retval)
return retval

View File

@ -16,10 +16,11 @@ import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import errors_impl as tf_errors
from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
errors_impl as tf_errors)
from lib.training import TrainingDataGenerator
from lib.utils import FaceswapError, get_backend, get_folder, get_image_paths
from lib.utils import FaceswapError, get_backend, get_folder, get_image_paths, get_tf_version
from plugins.train._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -129,8 +130,8 @@ class TrainerBase():
logger.debug("Setting up TensorBoard Logging")
log_dir = os.path.join(str(self._model.model_dir),
"{}_logs".format(self._model.name),
"session_{}".format(self._model.state.session_id))
f"{self._model.name}_logs",
f"session_{self._model.state.session_id}")
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
histogram_freq=0, # Must be 0 or hangs
write_graph=get_backend() != "amd",
@ -251,7 +252,16 @@ class TrainerBase():
logger.trace("Updating TensorBoard log")
logs = {log[0]: log[1]
for log in zip(self._model.state.loss_names, loss)}
self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
if get_tf_version() == 2.8:
# Bug in TF 2.8 where batch recording got deleted.
# ref: https://github.com/keras-team/keras/issues/16173
for name, value in logs.items():
tf.summary.scalar(
"batch_" + name,
value,
step=self._model._model._train_counter) # pylint:disable=protected-access
def _collate_and_store_loss(self, loss):
""" Collate the loss into totals for each side.
@ -297,11 +307,11 @@ class TrainerBase():
The loss for each side. List should contain 2 ``floats`` side "a" in position 0 and
side "b" in position `.
"""
output = ", ".join(["Loss {}: {:.5f}".format(side, side_loss)
output = ", ".join([f"Loss {side}: {side_loss:.5f}"
for side, side_loss in zip(("A", "B"), loss)])
timestamp = time.strftime("%H:%M:%S")
output = "[{}] [#{:05d}] {}".format(timestamp, self._model.iterations, output)
print("\r{}".format(output), end="")
output = f"[{timestamp}] [#{self._model.iterations:05d}] {output}"
print(f"\r{output}", end="")
def clear_tensorboard(self):
""" Stop Tensorboard logging.
@ -335,14 +345,14 @@ class _Feeder():
self._model = model
self._images = images
self._config = config
self._target = dict()
self._samples = dict()
self._masks = dict()
self._target = {}
self._samples = {}
self._masks = {}
self._feeds = {side: self._load_generator(idx).minibatch_ab(images[side], batch_size, side)
for idx, side in enumerate(("a", "b"))}
self._display_feeds = dict(preview=self._set_preview_feed(), timelapse=dict())
self._display_feeds = dict(preview=self._set_preview_feed(), timelapse={})
logger.debug("Initialized %s:", self.__class__.__name__)
def _load_generator(self, output_index):
@ -385,7 +395,7 @@ class _Feeder():
The side ("a" or "b") as key, :class:`~lib.training_data.TrainingDataGenerator` as
value.
"""
retval = dict()
retval = {}
for idx, side in enumerate(("a", "b")):
logger.debug("Setting preview feed: (side: '%s')", side)
preview_images = self._config.get("preview_images", 14)
@ -484,9 +494,9 @@ class _Feeder():
should not be generated, in which case currently stored previews should be deleted.
"""
if not do_preview:
self._samples = dict()
self._target = dict()
self._masks = dict()
self._samples = {}
self._target = {}
self._masks = {}
return
logger.debug("Generating preview")
for side in ("a", "b"):
@ -523,7 +533,7 @@ class _Feeder():
"""
num_images = self._config.get("preview_images", 14)
num_images = min(batch_size, num_images) if batch_size is not None else num_images
retval = dict()
retval = {}
for side in ("a", "b"):
logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images)
side_images = images[side] if images is not None else self._target[side]
@ -544,9 +554,9 @@ class _Feeder():
:class:`numpy.ndarrays` for creating a time-lapse frame
"""
batchsizes = []
samples = dict()
images = dict()
masks = dict()
samples = {}
images = {}
masks = {}
for side in ("a", "b"):
batch = next(self._display_feeds["timelapse"][side])
batchsizes.append(len(batch["samples"]))
@ -607,7 +617,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
self.__class__.__name__, model, coverage_ratio)
self._model = model
self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"]
self.images = dict()
self.images = {}
self._coverage_ratio = coverage_ratio
self._scaling = scaling
logger.debug("Initialized %s", self.__class__.__name__)
@ -630,9 +640,9 @@ class _Samples(): # pylint:disable=too-few-public-methods
A compiled preview image ready for display or saving
"""
logger.debug("Showing sample")
feeds = dict()
figures = dict()
headers = dict()
feeds = {}
figures = {}
headers = {}
for idx, side in enumerate(("a", "b")):
samples = self.images[side]
faces = samples[1]
@ -647,8 +657,8 @@ class _Samples(): # pylint:disable=too-few-public-methods
for side, samples in self.images.items():
other_side = "a" if side == "b" else "b"
predictions = [preds["{0}_{0}".format(side)],
preds["{}_{}".format(other_side, side)]]
predictions = [preds[f"{side}_{side}"],
preds[f"{other_side}_{side}"]]
display = self._to_full_frame(side, samples, predictions)
headers[side] = self._get_headers(side, display[0].shape[1])
figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
@ -716,7 +726,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
List of :class:`numpy.ndarray` of predictions received from the model
"""
logger.debug("Getting Predictions")
preds = dict()
preds = {}
standard = self._model.model.predict([feed_a, feed_b])
swapped = self._model.model.predict([feed_b, feed_a])
@ -904,9 +914,9 @@ class _Samples(): # pylint:disable=too-few-public-methods
total_width = width * 3
logger.debug("height: %s, total_width: %s", height, total_width)
font = cv2.FONT_HERSHEY_SIMPLEX
texts = ["{} ({})".format(titles[0], side),
"{0} > {0}".format(titles[0]),
"{} > {}".format(titles[0], titles[1])]
texts = [f"{titles[0]} ({side})",
f"{titles[0]} > {titles[0]}",
f"{titles[0]} > {titles[1]}"]
scaling = (width / 144) * 0.45
text_sizes = [cv2.getTextSize(texts[idx], font, scaling, 1)[0]
for idx in range(len(texts))]
@ -1002,7 +1012,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
logger.debug("Time-lapse output set to '%s'", self._output_file)
# Rewrite paths to pull from the training images so mask and face data can be accessed
images = dict()
images = {}
for side, input_ in zip(("a", "b"), (input_a, input_b)):
training_path = os.path.dirname(self._image_paths[side][0])
images[side] = [os.path.join(training_path, os.path.basename(pth))

View File

@ -1,2 +1,2 @@
-r _requirements_base.txt
tensorflow>=2.2.0,<2.7.0
tensorflow>=2.2.0,<2.9.0

View File

@ -1,2 +1,2 @@
-r _requirements_base.txt
tensorflow-gpu>=2.2.0,<2.7.0
tensorflow-gpu>=2.2.0,<2.9.0

163
setup.py
View File

@ -19,7 +19,7 @@ INSTALL_FAILED = False
# Tensorflow builds available from pypi
TENSORFLOW_REQUIREMENTS = {">=2.2.0,<2.4.0": ["10.1", "7.6"],
">=2.4.0,<2.5.0": ["11.0", "8.0"],
">=2.5.0,<2.7.0": ["11.2", "8.1"]}
">=2.5.0,<2.9.0": ["11.2", "8.1"]}
# Mapping of Python packages to their conda names if different from pip or in non-default channel
CONDA_MAPPING = {
# "opencv-python": ("opencv", "conda-forge"), # Periodic issues with conda-forge opencv
@ -43,9 +43,9 @@ class Environment():
self.enable_amd = False
self.enable_docker = False
self.enable_cuda = False
self.required_packages = list()
self.missing_packages = list()
self.conda_missing_packages = list()
self.required_packages = []
self.missing_packages = []
self.conda_missing_packages = []
self.process_arguments()
self.check_permission()
@ -54,6 +54,7 @@ class Environment():
self.output_runtime_info()
self.check_pip()
self.upgrade_pip()
self.set_ld_library_path()
self.installed_packages = self.get_installed_packages()
self.installed_packages.update(self.get_installed_conda_packages())
@ -104,7 +105,7 @@ class Environment():
args = [arg for arg in sys.argv] # pylint:disable=unnecessary-comprehension
if self.updater:
from lib.utils import get_backend # pylint:disable=import-outside-toplevel
args.append("--{}".format(get_backend()))
args.append(f"--{get_backend()}")
for arg in args:
if arg == "--installer":
@ -124,11 +125,11 @@ class Environment():
suffix = "cpu.txt"
req_files = ["_requirements_base.txt", f"requirements_{suffix}"]
pypath = os.path.dirname(os.path.realpath(__file__))
requirements = list()
git_requirements = list()
requirements = []
git_requirements = []
for req_file in req_files:
requirements_file = os.path.join(pypath, req_file)
with open(requirements_file) as req:
with open(requirements_file, encoding="utf8") as req:
for package in req.readlines():
package = package.strip()
# parse_requirements can't handle git dependencies, so extract and then
@ -157,15 +158,14 @@ class Environment():
if not self.updater:
self.output.info("The tool provides tips for installation\n"
"and installs required python packages")
self.output.info("Setup in %s %s" % (self.os_version[0], self.os_version[1]))
self.output.info(f"Setup in {self.os_version[0]} {self.os_version[1]}")
if not self.updater and not self.os_version[0] in ["Windows", "Linux", "Darwin"]:
self.output.error("Your system %s is not supported!" % self.os_version[0])
self.output.error(f"Your system {self.os_version[0]} is not supported!")
sys.exit(1)
def check_python(self):
""" Check python and virtual environment status """
self.output.info("Installed Python: {0} {1}".format(self.py_version[0],
self.py_version[1]))
self.output.info(f"Installed Python: {self.py_version[0]} {self.py_version[1]}")
if not (self.py_version[0].split(".")[0] == "3"
and self.py_version[0].split(".")[1] in ("7", "8")
and self.py_version[1] == "64bit") and not self.updater:
@ -179,7 +179,7 @@ class Environment():
self.output.info("Running in Conda")
if self.is_virtualenv:
self.output.info("Running in a Virtual Environment")
self.output.info("Encoding: {}".format(self.encoding))
self.output.info(f"Encoding: {self.encoding}")
def check_pip(self):
""" Check installed pip version """
@ -201,17 +201,16 @@ class Environment():
if not self.is_admin and not self.is_virtualenv:
pipexe.append("--user")
pipexe.append("pip")
run(pipexe)
run(pipexe, check=True)
import pip # pylint:disable=import-outside-toplevel
pip_version = pip.__version__
self.output.info("Installed pip: {}".format(pip_version))
self.output.info(f"Installed pip: {pip_version}")
def get_installed_packages(self):
""" Get currently installed packages """
installed_packages = dict()
chk = Popen("\"{}\" -m pip freeze".format(sys.executable),
shell=True, stdout=PIPE)
installed = chk.communicate()[0].decode(self.encoding).splitlines()
installed_packages = {}
with Popen(f"\"{sys.executable}\" -m pip freeze", shell=True, stdout=PIPE) as chk:
installed = chk.communicate()[0].decode(self.encoding).splitlines()
for pkg in installed:
if "==" not in pkg:
@ -227,7 +226,7 @@ class Environment():
chk = os.popen("conda list").read()
installed = [re.sub(" +", " ", line.strip())
for line in chk.splitlines() if not line.startswith("#")]
retval = dict()
retval = {}
for pkg in installed:
item = pkg.split(" ")
retval[item[0]] = item[1]
@ -253,7 +252,7 @@ class Environment():
# that corresponds to the installed Cuda/cuDNN versions
self.required_packages = [pkg for pkg in self.required_packages
if not pkg.startswith("tensorflow-gpu")]
tf_ver = "tensorflow-gpu{}".format(tf_ver)
tf_ver = f"tensorflow-gpu{tf_ver}"
self.required_packages.append(tf_ver)
return
@ -262,13 +261,12 @@ class Environment():
"Tensorflow currently has no official prebuild for your CUDA, cuDNN "
"combination.\nEither install a combination that Tensorflow supports or "
"build and install your own tensorflow-gpu.\r\n"
"CUDA Version: {}\r\n"
"cuDNN Version: {}\r\n"
f"CUDA Version: {self.cuda_version}\r\n"
f"cuDNN Version: {self.cudnn_version}\r\n"
"Help:\n"
"Building Tensorflow: https://www.tensorflow.org/install/install_sources\r\n"
"Tensorflow supported versions: "
"https://www.tensorflow.org/install/source#tested_build_configurations".format(
self.cuda_version, self.cudnn_version))
"https://www.tensorflow.org/install/source#tested_build_configurations")
custom_tf = input("Location of custom tensorflow-gpu wheel (leave "
"blank to manually install): ")
@ -277,9 +275,9 @@ class Environment():
custom_tf = os.path.realpath(os.path.expanduser(custom_tf))
if not os.path.isfile(custom_tf):
self.output.error("{} not found".format(custom_tf))
self.output.error(f"{custom_tf} not found")
elif os.path.splitext(custom_tf)[1] != ".whl":
self.output.error("{} is not a valid pip wheel".format(custom_tf))
self.output.error(f"{custom_tf} is not a valid pip wheel")
elif custom_tf:
self.required_packages.append(custom_tf)
@ -294,9 +292,57 @@ class Environment():
config = {"backend": backend}
pypath = os.path.dirname(os.path.realpath(__file__))
config_file = os.path.join(pypath, "config", ".faceswap")
with open(config_file, "w") as cnf:
with open(config_file, "w", encoding="utf8") as cnf:
json.dump(config, cnf)
self.output.info("Faceswap config written to: {}".format(config_file))
self.output.info(f"Faceswap config written to: {config_file}")
def set_ld_library_path(self):
""" Update the LD_LIBRARY_PATH environment variable when activating a conda environment
and revert it when deactivating.
Notes
-----
From Tensorflow 2.7, installing Cuda Toolkit from conda-forge and tensorflow from pip
causes tensorflow to not be able to locate shared libs and hence not use the GPU.
We update the environment variable for all instances using Conda as it shouldn't hurt
anything and may help avoid conflicts with globally installed Cuda
"""
if not self.is_conda or not self.enable_cuda:
return
if self.os_version[0] == "Windows":
return
conda_prefix = os.environ["CONDA_PREFIX"]
activate_folder = os.path.join(conda_prefix, "etc", "conda", "activate.d")
deactivate_folder = os.path.join(conda_prefix, "etc", "conda", "deactivate.d")
os.makedirs(activate_folder, exist_ok=True)
os.makedirs(deactivate_folder, exist_ok=True)
activate_script = os.path.join(conda_prefix, activate_folder, f"env_vars.sh")
deactivate_script = os.path.join(conda_prefix, deactivate_folder, f"env_vars.sh")
if os.path.isfile(activate_script):
# Only create file if it does not already exist. There may be instances where people
# have created their own scripts, but these should be few and far between and those
# people should already know what they are doing.
return
conda_libs = os.path.join(conda_prefix, "lib")
shebang = "#!/bin/sh\n\n"
with open(activate_script, "w", encoding="utf8") as afile:
afile.write(f"{shebang}")
afile.write("export OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH}\n")
afile.write(f"export LD_LIBRARY_PATH='{conda_libs}':${{LD_LIBRARY_PATH}}\n")
with open(deactivate_script, "w", encoding="utf8") as afile:
afile.write(f"{shebang}")
afile.write("export LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}\n")
afile.write("unset OLD_LD_LIBRARY_PATH\n")
self.output.info(f"Cuda search path set to '{conda_libs}'")
class Output():
@ -324,14 +370,14 @@ class Output():
""" Format INFO Text """
trm = "INFO "
if self.term_support_color:
trm = "{}INFO {} ".format(self.green, self.default_color)
trm = f"{self.green}INFO {self.default_color} "
print(trm + self.__indent_text_block(text))
def warning(self, text):
""" Format WARNING Text """
trm = "WARNING "
if self.term_support_color:
trm = "{}WARNING{} ".format(self.yellow, self.default_color)
trm = f"{self.yellow}WARNING{self.default_color} "
print(trm + self.__indent_text_block(text))
def error(self, text):
@ -339,7 +385,7 @@ class Output():
global INSTALL_FAILED # pylint:disable=global-statement
trm = "ERROR "
if self.term_support_color:
trm = "{}ERROR {} ".format(self.red, self.default_color)
trm = f"{self.red}ERROR {self.default_color} "
print(trm + self.__indent_text_block(text))
INSTALL_FAILED = True
@ -471,8 +517,8 @@ class CudaCheck(): # pylint:disable=too-few-public-methods
Initially just calls `nvcc -V` to get the installed version of Cuda currently in use.
If this fails, drills down to more OS specific checking methods.
"""
chk = Popen("nvcc -V", shell=True, stdout=PIPE, stderr=PIPE)
stdout, stderr = chk.communicate()
with Popen("nvcc -V", shell=True, stdout=PIPE, stderr=PIPE) as chk:
stdout, stderr = chk.communicate()
if not stderr:
version = re.search(r".*release (?P<cuda>\d+\.\d+)",
stdout.decode(locale.getpreferredencoding()))
@ -522,7 +568,7 @@ class CudaCheck(): # pylint:disable=too-few-public-methods
if not cudnn_checkfile:
return
found = 0
with open(cudnn_checkfile, "r") as ofile:
with open(cudnn_checkfile, "r", encoding="utf8") as ofile:
for line in ofile:
if line.lower().startswith("#define cudnn_major"):
major = line[line.rfind(" ") + 1:].strip()
@ -551,7 +597,7 @@ class CudaCheck(): # pylint:disable=too-few-public-methods
chk = os.popen("ldconfig -p | grep -P \"libcudnn.so.\\d+\" | head -n 1").read()
chk = chk.strip().replace("libcudnn.so.", "")
if not chk:
return list()
return []
cudnn_vers = chk[0]
header_files = [f"cudnn_v{cudnn_vers}.h"] + self._cudnn_header_files
@ -572,7 +618,7 @@ class CudaCheck(): # pylint:disable=too-few-public-methods
"""
# TODO A more reliable way of getting the windows location
if not self.cuda_path:
return list()
return []
scandir = os.path.join(self.cuda_path, "include")
cudnn_checkfiles = [os.path.join(scandir, header) for header in self._cudnn_header_files]
return cudnn_checkfiles
@ -701,7 +747,7 @@ class Install():
channel = None if len(pkg) != 2 else pkg[1]
pkg = pkg[0]
if version:
pkg = "{}{}".format(pkg, ",".join("".join(spec) for spec in version))
pkg = f"{pkg}{','.join(''.join(spec) for spec in version)}"
if self.env.is_conda and not pkg.startswith("git"):
if pkg.startswith("tensorflow-gpu"):
# From TF 2.4 onwards, Anaconda Tensorflow becomes a mess. The version of 2.5
@ -760,13 +806,14 @@ class Install():
package = f"\"{package}\""
condaexe.append(package)
self.output.info("Installing {}".format(package.replace("\"", "")))
clean_pkg = package.replace("\"", "")
self.output.info(f"Installing {clean_pkg}")
shell = self.env.os_version[0] == "Windows"
try:
if verbose:
run(condaexe, check=True, shell=shell)
else:
with open(os.devnull, "w") as devnull:
with open(os.devnull, "w", encoding="utf8") as devnull:
run(condaexe, stdout=devnull, stderr=devnull, check=True, shell=shell)
except CalledProcessError:
if not conda_only:
@ -809,14 +856,16 @@ class Install():
pkgs = ["cudatoolkit", "cudnn"]
shell = self.env.os_version[0] == "Windows"
for pkg in pkgs:
chk = Popen(condaexe + [pkg], shell=shell, stdout=PIPE)
available = [line.split()
for line in chk.communicate()[0].decode(self.env.encoding).splitlines()
if line.startswith(pkg)]
compatible = [req for req in available
if (pkg == "cudatoolkit" and req[1].startswith(versions[0]))
or (pkg == "cudnn" and versions[0] in req[2]
and req[1].startswith(versions[1]))]
with Popen(condaexe + [pkg], shell=shell, stdout=PIPE) as chk:
available = [line.split()
for line
in chk.communicate()[0].decode(self.env.encoding).splitlines()
if line.startswith(pkg)]
compatible = [req for req in available
if (pkg == "cudatoolkit" and req[1].startswith(versions[0]))
or (pkg == "cudnn" and versions[0] in req[2]
and req[1].startswith(versions[1]))]
candidate = "==".join(sorted(compatible, key=lambda x: x[1])[-1][:2])
self.conda_installer(candidate, verbose=True, conda_only=True)
@ -828,6 +877,8 @@ class Tips():
def docker_no_cuda(self):
""" Output Tips for Docker without Cuda """
path = os.path.dirname(os.path.realpath(__file__))
self.output.info(
"1. Install Docker\n"
"https://www.docker.com/community-edition\n\n"
@ -837,7 +888,7 @@ class Tips():
"# without GUI\n"
"docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\tdeepfakes-cpu\n\n"
"# with gui. tools.py gui working.\n"
"## enable local access to X11 server\n"
@ -845,7 +896,7 @@ class Tips():
"## create container\n"
"nvidia-docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
"\t-e DISPLAY=unix$DISPLAY \\ \n"
"\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n"
@ -854,12 +905,13 @@ class Tips():
"\t-e UID=`id -u` \\ \n"
"\tdeepfakes-cpu \n\n"
"4. Open a new terminal to run faceswap.py in /srv\n"
"docker exec -it deepfakes-cpu bash".format(
path=os.path.dirname(os.path.realpath(__file__))))
"docker exec -it deepfakes-cpu bash")
self.output.info("That's all you need to do with a docker. Have fun.")
def docker_cuda(self):
""" Output Tips for Docker wit Cuda"""
path = os.path.dirname(os.path.realpath(__file__))
self.output.info(
"1. Install Docker\n"
"https://www.docker.com/community-edition\n\n"
@ -873,7 +925,7 @@ class Tips():
"# without gui \n"
"docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\tdeepfakes-gpu\n\n"
"# with gui.\n"
"## enable local access to X11 server\n"
@ -883,7 +935,7 @@ class Tips():
"## create container\n"
"nvidia-docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
"\t-e DISPLAY=unix$DISPLAY \\ \n"
"\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n"
@ -892,8 +944,7 @@ class Tips():
"\t-e UID=`id -u` \\ \n"
"\tdeepfakes-gpu\n\n"
"6. Open a new terminal to interact with the project\n"
"docker exec deepfakes-gpu python /srv/faceswap.py gui\n".format(
path=os.path.dirname(os.path.realpath(__file__))))
"docker exec deepfakes-gpu python /srv/faceswap.py gui\n")
def macos(self):
""" Output Tips for macOS"""

View File

@ -70,41 +70,3 @@ def test_loss_wrapper(loss_func):
else:
output = output.numpy()
assert output.dtype == "float32" and not np.isnan(output)
@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
def test_dssim_channels_last(dummy): # pylint:disable=unused-argument
""" Basic test for DSSIM Loss """
prev_data = K.image_data_format()
K.set_image_data_format('channels_last')
for input_dim, kernel_size in zip([32, 33], [2, 3]):
input_shape = [input_dim, input_dim, 3]
var_x = np.random.random_sample(4 * input_dim * input_dim * 3)
var_x = var_x.reshape([4] + input_shape)
var_y = np.random.random_sample(4 * input_dim * input_dim * 3)
var_y = var_y.reshape([4] + input_shape)
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape,
activation='relu'))
model.add(Conv2D(3, (3, 3), padding='same', input_shape=input_shape,
activation='relu'))
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
model.compile(loss=losses.DSSIMObjective(kernel_size=kernel_size),
metrics=['mse'],
optimizer=adam)
model.fit(var_x, var_y, batch_size=2, epochs=1, shuffle='batch')
# Test same
x_1 = K.constant(var_x, 'float32')
x_2 = K.constant(var_x, 'float32')
dssim = losses.DSSIMObjective(kernel_size=kernel_size)
assert_allclose(0.0, K.eval(dssim(x_1, x_2)), atol=1e-4)
# Test opposite
x_1 = K.zeros([4] + input_shape)
x_2 = K.ones([4] + input_shape)
dssim = losses.DSSIMObjective(kernel_size=kernel_size)
assert_allclose(0.5, K.eval(dssim(x_1, x_2)), atol=1e-4)
K.set_image_data_format(prev_data)

View File

@ -76,11 +76,11 @@ def _test_optimizer(optimizer, target=0.75):
@pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()])
def test_adam(dummy): # pylint:disable=unused-argument
""" Test for custom Adam optimizer """
_test_optimizer(k_optimizers.Adam(), target=0.5)
_test_optimizer(k_optimizers.Adam(decay=1e-3), target=0.5)
_test_optimizer(k_optimizers.Adam(), target=0.45)
_test_optimizer(k_optimizers.Adam(decay=1e-3), target=0.45)
@pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()])
def test_adabelief(dummy): # pylint:disable=unused-argument
""" Test for custom Adam optimizer """
_test_optimizer(optimizers.AdaBelief(), target=0.5)
_test_optimizer(optimizers.AdaBelief(), target=0.45)

View File

@ -25,5 +25,6 @@ def test_backend(dummy): # pylint:disable=unused-argument
def test_keras(dummy): # pylint:disable=unused-argument
""" Sanity check to ensure that tensorflow keras is being used for CPU and standard
keras for AMD. """
assert ((_BACKEND == "cpu" and keras.__version__ in ("2.3.0-tf", "2.4.0")) or
assert ((_BACKEND == "cpu" and keras.__version__ in ("2.3.0-tf", "2.4.0",
"2.6.0", "2.7.0", "2.8.0")) or
(_BACKEND == "amd" and keras.__version__ == "2.2.4"))