Mirror of https://github.com/zebrajr/faceswap.git — synced 2025-12-06 00:20:09 +01:00
Update code to support Tensorflow versions up to 2.8 (#1213)
* Update maximum tf version in setup + requirements
* - bump max version of tf version in launcher
  - standardise tf version check
* update keras get_custom_objects for tf>2.6
* bugfix: force black text in GUI file dialogs (linux)
* dssim loss - Move to stock tf.ssim function
* Update optimizer imports for compatibility
* fix logging for tf2.8
* Fix GUI graphing for TF2.8
* update tests
* bump requirements.txt versions
* Remove limit on nvidia-ml-py
* Graphing bugfixes
  - Prevent live graph from displaying if data not yet available
* bugfix: Live graph. Collect loss labels correctly
* fix: live graph - swallow inconsistent loss errors
* Bugfix: Prevent live graph from clearing during training
* Fix graphing for AMD
This commit is contained in:
parent cda49b3c3c
commit c1512fd41d
@@ -1,19 +1,16 @@
-tqdm>=4.62
+tqdm>=4.64
 psutil>=5.8.0
-numpy>=1.18.0,<1.20.0
-opencv-python>=4.5.3.0
-pillow>=8.3.1
-scikit-learn>=0.24.2
-fastcluster>=1.1.26
+numpy>=1.18.0
+opencv-python>=4.5.5.0
+pillow>=9.0.1
+scikit-learn>=1.0.2
+fastcluster>=1.2.4
 # matplotlib 3.3.1 breaks custom toolbar in graph popup
 matplotlib>=3.2.0,<3.3.0
 imageio>=2.9.0
-imageio-ffmpeg>=0.4.5
+imageio-ffmpeg>=0.4.7
 ffmpy==0.2.3
 # Exclude badly numbered Python2 version of nvidia-ml-py
-# nvidia-ml-py>=11.450,<300
-# v11.515.0 changes dtype of output items. Pinned for now
-# TODO update code to use latest version
-nvidia-ml-py>=11.450,<11.515
+nvidia-ml-py>=11.510,<300
 pywin32>=228 ; sys_platform == "win32"
 pynvx==1.0.0 ; sys_platform == "darwin"
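For context on the loosened nvidia-ml-py pin above, here is a minimal sketch (not part of the commit) of how GPU stats are typically queried through that package. The str/bytes normalisation is the kind of output-dtype difference the earlier `<11.515` pin was guarding against.

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)   # first GPU
name = pynvml.nvmlDeviceGetName(handle)         # str on current releases, bytes on some older ones
if isinstance(name, bytes):                     # normalise so either package version behaves the same
    name = name.decode("utf-8")
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)  # .total / .free / .used in bytes
print(f"{name}: {memory.free / 1024 ** 2:.0f}MB free of {memory.total / 1024 ** 2:.0f}MB")
pynvml.nvmlShutdown()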
@@ -9,8 +9,8 @@ from importlib import import_module

 from lib.gpu_stats import set_exclude_devices, GPUStats
 from lib.logger import crash_log, log_setup
-from lib.utils import (FaceswapError, get_backend, KerasFinder, safe_shutdown, set_backend,
-                       set_system_verbosity)
+from lib.utils import (FaceswapError, get_backend, get_tf_version, KerasFinder, safe_shutdown,
+                       set_backend, set_system_verbosity)

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -41,7 +41,7 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
         self._test_for_tf_version()
         self._test_for_gui()
         cmd = os.path.basename(sys.argv[0])
-        src = "tools.{}".format(self._command.lower()) if cmd == "tools.py" else "scripts"
+        src = f"tools.{self._command.lower()}" if cmd == "tools.py" else "scripts"
         mod = ".".join((src, self._command.lower()))
         module = import_module(mod)
         script = getattr(module, self._command.title())

@@ -53,15 +53,15 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
         Raises
         ------
         FaceswapError
-            If Tensorflow is not found, or is not between versions 2.2 and 2.6
+            If Tensorflow is not found, or is not between versions 2.2 and 2.8
         """
         min_ver = 2.2
-        max_ver = 2.6
+        max_ver = 2.8
         try:
             # Ensure tensorflow doesn't pin all threads to one core when using Math Kernel Library
             os.environ["TF_MIN_GPU_MULTIPROCESSOR_COUNT"] = "4"
             os.environ["KMP_AFFINITY"] = "disabled"
-            import tensorflow as tf  # pylint:disable=import-outside-toplevel
+            import tensorflow as tf  # noqa pylint:disable=import-outside-toplevel,unused-import
         except ImportError as err:
             if "DLL load failed while importing" in str(err):
                 msg = (

@@ -77,14 +77,14 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
                     f"error: {str(err)}")
                 self._handle_import_error(msg)

-        tf_ver = float(".".join(tf.__version__.split(".")[:2]))  # pylint:disable=no-member
+        tf_ver = get_tf_version()
         if tf_ver < min_ver:
-            msg = ("The minimum supported Tensorflow is version {} but you have version {} "
-                   "installed. Please upgrade Tensorflow.".format(min_ver, tf_ver))
+            msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version "
+                   f"{tf_ver} installed. Please upgrade Tensorflow.")
            self._handle_import_error(msg)
         if tf_ver > max_ver:
-            msg = ("The maximum supported Tensorflow is version {} but you have version {} "
-                   "installed. Please downgrade Tensorflow.".format(max_ver, tf_ver))
+            msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version "
+                   f"{tf_ver} installed. Please downgrade Tensorflow.")
            self._handle_import_error(msg)
         logger.debug("Installed Tensorflow Version: %s", tf_ver)
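The launcher now delegates the version parse to a get_tf_version() helper in lib.utils (the "standardise tf version check" item in the commit message). The helper's body is not part of this diff; a plausible sketch, assuming it simply reduces tf.__version__ to a major.minor float exactly as the removed inline code did:

def get_tf_version():
    """ Return the installed Tensorflow version as a major.minor float (e.g. 2.8). """
    import tensorflow as tf  # pylint:disable=import-outside-toplevel
    return float(".".join(tf.__version__.split(".")[:2]))

# Example use mirroring the launcher gate:
#     if not 2.2 <= get_tf_version() <= 2.8: request an upgrade/downgrade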
@@ -7,10 +7,12 @@ import zlib

 import numpy as np
+import tensorflow as tf
-from tensorflow.core.util import event_pb2
-from tensorflow.python.framework import errors_impl as tf_errors
+from tensorflow.core.util import event_pb2  # pylint:disable=no-name-in-module
+from tensorflow.python.framework import (  # pylint:disable=no-name-in-module
+    errors_impl as tf_errors)

 from lib.serializer import get_serializer
 from lib.utils import get_backend

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -43,7 +45,7 @@ class _LogFiles():
             The full path of each log file for each training session id that has been run
         """
         logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
-        retval = dict()
+        retval = {}
         for dirpath, _, filenames in os.walk(self._logs_folder):
             if not any(filename.startswith("events.out.tfevents") for filename in filenames):
                 continue

@@ -133,7 +135,7 @@ class _Cache():
     def __init__(self, session_ids):
         logger.debug("Initializing: %s: (session_ids: %s)", self.__class__.__name__, session_ids)
         self._data = {idx: None for idx in session_ids}
-        self._carry_over = dict()
+        self._carry_over = {}
         self._loss_labels = []
         logger.debug("Initialized: %s", self.__class__.__name__)

@@ -158,18 +160,19 @@ class _Cache():
         """
         logger.debug("Caching event data: (session_id: %s, labels: %s, data points: %s, "
                      "is_live: %s)", session_id, labels, len(data), is_live)
-        if not data:
-            logger.debug("No data to cache")
-            return
-
         if labels:
             logger.debug("Setting loss labels: %s", labels)
             self._loss_labels = labels

+        if not data:
+            logger.debug("No data to cache")
+            return
+
         timestamps, loss = self._to_numpy(data, is_live)

         if not is_live or (is_live and not self._data.get(session_id, None)):
-            self._data[session_id] = dict(labels=labels,
+            self._data[session_id] = dict(labels=self._loss_labels,
                                           loss=zlib.compress(loss),
                                           loss_shape=loss.shape,
                                           timestamps=zlib.compress(timestamps),

@@ -207,10 +210,30 @@ class _Cache():
                                 for idx in sorted(data)])
         times, loss = self._process_data(data, times, loss, is_live)

-        times, loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
+        if is_live and not all(len(val) == len(self._loss_labels) for val in loss):
+            # TODO Many attempts have been made to fix this for live graph logging, and the issue
+            # of non-consistent loss record sizes keeps coming up. In the meantime we shall swallow
+            # any loss values that are of incorrect length so graph remains functional. This will,
+            # most likely, lead to a mismatch on iteration count so a proper fix should be
+            # implemented.
+
+            # Timestamps and loss appears to remain consistent with each other, but sometimes loss
+            # appears non-consistent. eg (lengths):
+            # [2, 2, 2, 2, 2, 2, 2, 0] - last loss collection has zero length
+            # [1, 2, 2, 2, 2, 2, 2, 2] - 1st loss collection has 1 length
+            # [2, 2, 2, 3, 2, 2, 2] - 4th loss collection has 3 length
+
+            logger.debug("Inconsistent loss found in collection: %s", loss)
+            for idx in reversed(range(len(loss))):
+                if len(loss[idx]) != len(self._loss_labels):
+                    logger.debug("Removing loss/timestamps at position %s", idx)
+                    del loss[idx]
+                    del times[idx]
+
+        times, loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
         logger.debug("Converted to numpy: (data points: %s, timestamps shape: %s, loss shape: %s)",
                      len(data), times.shape, loss.shape)

         return times, loss

     def _collect_carry_over(self, data):

@@ -334,7 +357,7 @@ class _Cache():

         dtype = "float32" if metric == "loss" else "float64"

-        retval = dict()
+        retval = {}
         for idx, data in raw.items():
             val = {metric: np.frombuffer(zlib.decompress(data[metric]),
                                          dtype=dtype).reshape(data[f"{metric}_shape"])}

@@ -461,7 +484,7 @@ class TensorBoardLogs():
             and list of loss values for each step
         """
         logger.debug("Getting loss: (session_id: %s)", session_id)
-        retval = dict()
+        retval = {}
         for idx in [session_id] if session_id else self.session_ids:
             self._check_cache(idx)
             data = self._cache.get_data(idx, "loss")

@@ -493,7 +516,7 @@ class TensorBoardLogs():

         logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
                      session_id, self._is_training)
-        retval = dict()
+        retval = {}
         for idx in [session_id] if session_id else self.session_ids:
             self._check_cache(idx)
             data = self._cache.get_data(idx, "timestamps")

@@ -565,7 +588,7 @@ class _EventParser():  # pylint:disable=too-few-public-methods
         session_id: int
             The session id that the data is being cached for
         """
-        data = dict()
+        data = {}
         try:
             for record in self._iterator:
                 event = event_pb2.Event.FromString(record)  # pylint:disable=no-member

@@ -573,8 +596,11 @@ class _EventParser():  # pylint:disable=too-few-public-methods
                     continue
                 if event.summary.value[0].tag == "keras":
                     self._parse_outputs(event)
+                if get_backend() == "amd":
+                    # No model is logged for AMD so need to get loss labels from state file
+                    self._add_amd_loss_labels(session_id)
                 if event.summary.value[0].tag.startswith("batch_"):
-                    data[event.step] = self._process_event(event, data.get(event.step, dict()))
+                    data[event.step] = self._process_event(event, data.get(event.step, {}))

         except tf_errors.DataLossError as err:
             logger.warning("The logs for Session %s are corrupted and cannot be displayed. "

@@ -605,10 +631,6 @@ class _EventParser():  # pylint:disable=too-few-public-methods
         config = serializer.unmarshal(struct)["config"]
         model_outputs = self._get_outputs(config)

-        # loss length of unique should be 3:
-        # - decoder_both, 1, 2
-        # - docoder_a, decoder_b, 1
-        split_output = len(np.unique(model_outputs[..., :2])) != 3
         for side_outputs, side in zip(model_outputs, ("a", "b")):
             logger.debug("side: '%s', outputs: '%s'", side, side_outputs)
             layer_name = side_outputs[0][0]

@@ -618,8 +640,10 @@ class _EventParser():  # pylint:disable=too-few-public-methods
             layer_outputs = self._get_outputs(output_config)
             for output in layer_outputs:  # Drill into sub-model to get the actual output names
                 loss_name = output[0][0]
-                if not split_output:  # Rename losses to reflect the side's output
-                    loss_name = f"{loss_name.replace('_both', '')}_{side}"
+                if loss_name[-2:] not in ("_a", "_b"):  # Rename losses to reflect the side output
+                    new_name = f"{loss_name.replace('_both', '')}_{side}"
+                    logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)
+                    loss_name = new_name
                 if loss_name not in self._loss_labels:
                     logger.debug("Adding loss name: '%s'", loss_name)
                     self._loss_labels.append(loss_name)

@@ -650,6 +674,28 @@ class _EventParser():  # pylint:disable=too-few-public-methods
                      outputs, outputs.shape)
         return outputs

+    def _add_amd_loss_labels(self, session_id):
+        """ It is not possible to store the model config in the Tensorboard logs for AMD so we
+        need to obtain the loss labels from the model's state file. This is called now so we know
+        event data is being written, and therefore the most current loss label data is available
+        in the state file.
+
+        Loss names are added to :attr:`_loss_labels`
+
+        Parameters
+        ----------
+        session_id: int
+            The session id that the data is being cached for
+        """
+        if self._cache._loss_labels:  # pylint:disable=protected-access
+            return
+        # Import global session here to prevent circular import
+        from . import Session  # pylint:disable=import-outside-toplevel
+        loss_labels = sorted(Session.get_loss_keys(session_id=session_id))
+        self._loss_labels = loss_labels
+        logger.debug("Collated loss labels: %s", self._loss_labels)
+
     @classmethod
     def _process_event(cls, event, step):
         """ Process a single Tensorflow event.

@@ -670,8 +716,19 @@ class _EventParser():  # pylint:disable=too-few-public-methods
             The given step `dict` with the given event data added to it.
         """
         summary = event.summary.value[0]

         if summary.tag in ("batch_loss", "batch_total"):  # Pre tf2.3 totals were "batch_total"
             step["timestamp"] = event.wall_time
             return step
-        step.setdefault("loss", list()).append(summary.simple_value)
+
+        loss = summary.simple_value
+        if not loss:
+            # Need to convert a tensor to a float for TF2.8 logged data. This maybe due to change
+            # in logging or may be due to work around put in place in FS training function for the
+            # following bug in TF 2.8 when writing records:
+            # https://github.com/keras-team/keras/issues/16173
+            loss = float(tf.make_ndarray(summary.tensor))
+
+        step.setdefault("loss", []).append(loss)

         return step
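The _process_event change above handles TF 2.8 writing batch losses as tensors rather than simple values. A standalone sketch (illustrative only, not faceswap code) of reading scalar losses out of a Tensorboard events file with the same fallback:

import tensorflow as tf
from tensorflow.core.util import event_pb2  # pylint:disable=no-name-in-module


def iter_batch_losses(events_path):
    """ Yield (step, tag, loss) for each batch_* scalar in a Tensorboard events file. """
    for record in tf.data.TFRecordDataset(events_path):
        event = event_pb2.Event.FromString(record.numpy())  # pylint:disable=no-member
        if not event.summary.value:
            continue
        summary = event.summary.value[0]
        if not summary.tag.startswith("batch_"):
            continue
        loss = summary.simple_value
        if not loss:  # TF 2.8 stores the scalar as a tensor proto instead
            loss = float(tf.make_ndarray(summary.tensor))
        yield event.step, summary.tag, loss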
@@ -17,6 +17,7 @@ from threading import Event
 import numpy as np

 from lib.serializer import get_serializer
+from lib.utils import get_backend

 from .event_reader import TensorBoardLogs

@@ -62,9 +63,9 @@ class GlobalSession():
     def batch_sizes(self):
         """ dict: The batch sizes for each session_id for the model. """
         if self._state is None:
-            return dict()
+            return {}
         return {int(sess_id): sess["batchsize"]
-                for sess_id, sess in self._state.get("sessions", dict()).items()}
+                for sess_id, sess in self._state.get("sessions", {}).items()}

     @property
     def full_summary(self):

@@ -86,7 +87,7 @@ class GlobalSession():

     def _load_state_file(self):
         """ Load the current state file to :attr:`_state`. """
-        state_file = os.path.join(self._model_dir, "{}_state.json".format(self._model_name))
+        state_file = os.path.join(self._model_dir, f"{self._model_name}_state.json")
         logger.debug("Loading State: '%s'", state_file)
         serializer = get_serializer("json")
         self._state = serializer.load(state_file)

@@ -125,8 +126,7 @@ class GlobalSession():
         self._model_dir = model_folder
         self._model_name = model_name
         self._load_state_file()
-        self._tb_logs = TensorBoardLogs(os.path.join(self._model_dir,
-                                                     "{}_logs".format(self._model_name)),
+        self._tb_logs = TensorBoardLogs(os.path.join(self._model_dir, f"{self._model_name}_logs"),
                                         is_training)

         self._summary = SessionsSummary(self)

@@ -140,7 +140,7 @@ class GlobalSession():

     def clear(self):
         """ Clear the currently loaded session. """
-        self._state = dict()
+        self._state = {}
         self._model_dir = None
         self._model_name = None

@@ -173,13 +173,13 @@ class GlobalSession():

         loss_dict = self._tb_logs.get_loss(session_id=session_id)
         if session_id is None:
-            retval = dict()
+            retval = {}
             for key in sorted(loss_dict):
                 for loss_key, loss in loss_dict[key].items():
                     retval.setdefault(loss_key, []).extend(loss)
             retval = {key: np.array(val, dtype="float32") for key, val in retval.items()}
         else:
-            retval = loss_dict.get(session_id, dict())
+            retval = loss_dict.get(session_id, {})

         if self._is_training:
             self._is_querying.clear()

@@ -239,14 +239,21 @@ class GlobalSession():
             The loss keys for the given session. If ``None`` is passed as session_id then a unique
             list of all loss keys for all sessions is returned
         """
-        loss_keys = {sess_id: list(logs.keys())
-                     for sess_id, logs in self._tb_logs.get_loss(session_id=session_id).items()}
+        if get_backend() == "amd":
+            # We can't log the graph in Tensorboard logs for AMD so need to obtain from state file
+            loss_keys = {int(sess_id): [name for name in session["loss_names"] if name != "total"]
+                         for sess_id, session in self._state["sessions"].items()}
+        else:
+            loss_keys = {sess_id: list(logs.keys())
+                         for sess_id, logs
+                         in self._tb_logs.get_loss(session_id=session_id).items()}

         if session_id is None:
             retval = list(set(loss_key
                               for session in loss_keys.values()
                               for loss_key in session))
         else:
-            retval = loss_keys[session_id]
+            retval = loss_keys.get(session_id)
         return retval

@@ -334,7 +341,7 @@ class SessionsSummary():  # pylint:disable=too-few-public-methods
         """
         if self._per_session_stats is None:
             logger.debug("Collating per session stats")
-            compiled = list()
+            compiled = []
             for session_id, ts_data in self._time_stats.items():
                 logger.debug("Compiling session ID: %s", session_id)
                 if self._state is None:

@@ -446,15 +453,15 @@ class SessionsSummary():  # pylint:disable=too-few-public-methods
         retval = []
         for summary in compiled_stats:
             hrs, mins, secs = self._convert_time(summary["elapsed"])
-            stats = dict()
+            stats = {}
             for key in summary:
                 if key not in ("start", "end", "elapsed", "rate"):
                     stats[key] = summary[key]
                     continue
             stats["start"] = time.strftime("%x %X", time.localtime(summary["start"]))
             stats["end"] = time.strftime("%x %X", time.localtime(summary["end"]))
-            stats["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
-            stats["rate"] = "{0:.1f}".format(summary["rate"])
+            stats["elapsed"] = f"{hrs}:{mins}:{secs}"
+            stats["rate"] = f"{summary['rate']:.1f}"
             retval.append(stats)
         return retval

@@ -474,9 +481,9 @@ class SessionsSummary():  # pylint:disable=too-few-public-methods
         """
         hrs = int(timestamp // 3600)
         if hrs < 10:
-            hrs = "{0:02d}".format(hrs)
-        mins = "{0:02d}".format((int(timestamp % 3600) // 60))
-        secs = "{0:02d}".format((int(timestamp % 3600) % 60))
+            hrs = f"{hrs:02d}"
+        mins = f"{(int(timestamp % 3600) // 60):02d}"
+        secs = f"{(int(timestamp % 3600) % 60):02d}"
         return hrs, mins, secs

@@ -529,7 +536,7 @@ class Calculations():
         self._iterations = 0
         self._limit = 0
         self._start_iteration = 0
-        self._stats = dict()
+        self._stats = {}
         self.refresh()
         logger.debug("Initialized %s", self.__class__.__name__)

@@ -630,7 +637,7 @@ class Calculations():
             if self._args["flatten_outliers"]:
                 loss = self._flatten_outliers(loss)

-            self.stats["raw_{}".format(loss_name)] = loss
+            self.stats[f"raw_{loss_name}"] = loss

         self._iterations = 0 if not iterations else min(iterations)
         if self._limit > 1:

@@ -642,7 +649,7 @@ class Calculations():
         if len(iterations) > 1:
             # Crop all losses to the same number of items
             if self._iterations == 0:
-                self.stats = {lossname: np.array(list(), dtype=loss.dtype)
+                self.stats = {lossname: np.array([], dtype=loss.dtype)
                               for lossname, loss in self.stats.items()}
             else:
                 self.stats = {lossname: loss[:self._iterations]

@@ -722,7 +729,7 @@ class Calculations():
         logger.debug("Calculating totals rate")
         batchsizes = _SESSION.batch_sizes
         total_timestamps = _SESSION.get_timestamps(None)
-        rate = list()
+        rate = []
         for sess_id in sorted(total_timestamps.keys()):
             batchsize = batchsizes[sess_id]
             timestamps = total_timestamps[sess_id]

@@ -737,10 +744,10 @@ class Calculations():
             if selection == "raw":
                 continue
             logger.debug("Calculating: %s", selection)
-            method = getattr(self, "_calc_{}".format(selection))
+            method = getattr(self, f"_calc_{selection}")
             raw_keys = [key for key in self._stats if key.startswith("raw_")]
             for key in raw_keys:
-                selected_key = "{}_{}".format(selection, key.replace("raw_", ""))
+                selected_key = f"{selection}_{key.replace('raw_', '')}"
                 self._stats[selected_key] = method(self._stats[key])

     def _calc_avg(self, data):

@@ -866,7 +873,7 @@ class _ExponentialMovingAverage():  # pylint:disable=too-few-public-methods
         optimizations.
         """
         # Use :func:`np.finfo(dtype).eps` if you are worried about accuracy and want to be safe.
-        epsilon = np.finfo(self._dtype).tiny
+        epsilon = np.finfo(self._dtype).tiny  # pylint:disable=no-member
         # If this produces an OverflowError, make epsilon larger:
         retval = int(np.log(epsilon) / np.log(1 - self._alpha)) + 1
         logger.debug("row_size: %s", retval)
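The row_size calculation in _ExponentialMovingAverage above finds the window length after which an EMA weight (1 - alpha)^n underflows the working dtype, so older rows can be ignored. A quick worked example with an assumed smoothing value (illustrative only):

import numpy as np

alpha = 0.1                            # assumed smoothing value for the example
epsilon = np.finfo("float32").tiny     # ~1.18e-38
row_size = int(np.log(epsilon) / np.log(1 - alpha)) + 1
print(row_size)                        # -> 829: contributions beyond ~829 rows underflow float32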
@@ -63,13 +63,10 @@ class PreviewExtract(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
             return
         filename = "extract_convert_preview"
         now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        filename = os.path.join(location,
-                                "{}_{}.{}".format(filename,
-                                                  now,
-                                                  "png"))
+        filename = os.path.join(location, f"{filename}_{now}.png")
         get_images().previewoutput[0].save(filename)
         logger.debug("Saved preview to %s", filename)
-        print("Saved preview to {}".format(filename))
+        print(f"Saved preview to {filename}")


 class PreviewTrain(DisplayOptionalPage):  # pylint: disable=too-many-ancestors

@@ -125,7 +122,7 @@ class PreviewTrain(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
         should_update = self.update_preview.get()

         for name in sortednames:
-            if name not in existing.keys():
+            if name not in existing:
                 self.add_child(name)
             elif should_update:
                 tab_id = existing[name]

@@ -197,19 +194,16 @@ class PreviewTrainCanvas(ttk.Frame):  # pylint: disable=too-many-ancestors
         """ Save the figure to file """
         filename = self.name
         now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        filename = os.path.join(location,
-                                "{}_{}.{}".format(filename,
-                                                  now,
-                                                  "png"))
+        filename = os.path.join(location, f"{filename}_{now}.png")
         get_images().previewtrain[self.name][0].save(filename)
         logger.debug("Saved preview to %s", filename)
-        print("Saved preview to {}".format(filename))
+        print(f"Saved preview to {filename}")


 class GraphDisplay(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
     """ The Graph Tab of the Display section """
     def __init__(self, parent, tab_name, helptext, wait_time, command=None):
-        self._trace_vars = dict()
+        self._trace_vars = {}
         super().__init__(parent, tab_name, helptext, wait_time, command)

     def set_vars(self):

@@ -370,6 +364,8 @@ class GraphDisplay(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
             logger.trace("Loading graph")
             self.display_item = Session
             self._add_trace_variables()
+        elif Session.is_training and self.display_item is not None:
+            logger.trace("Graph already displayed. Nothing to do.")
         else:
             logger.trace("Clearing graph")
             self.display_item = None

@@ -384,9 +380,15 @@ class GraphDisplay(DisplayOptionalPage):  # pylint: disable=too-many-ancestors

         logger.debug("Adding graph")
         existing = list(self.subnotebook_get_titles_ids().keys())
-        loss_keys = [key
-                     for key in self.display_item.get_loss_keys(Session.session_ids[-1])
-                     if key != "total"]
+
+        loss_keys = self.display_item.get_loss_keys(Session.session_ids[-1])
+        if not loss_keys:
+            # Reload if we attempt to get loss keys before data is written
+            logger.debug("Waiting for Session Data to become available to graph")
+            self.after(1000, self.display_item_process)
+            return
+
+        loss_keys = [key for key in loss_keys if key != "total"]
         display_tabs = sorted(set(key[:-1].rstrip("_") for key in loss_keys))

         for loss_key in display_tabs:

@@ -472,7 +474,7 @@ class GraphDisplay(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
         for name, (var, trace) in self._trace_vars.items():
             logger.debug("Clearing trace from variable: %s", name)
             var.trace_vdelete("w", trace)
-        self._trace_vars = dict()
+        self._trace_vars = {}

     def close(self):
         """ Clear the plots from RAM """
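The graph tab above now defers building itself until loss keys exist by re-scheduling display_item_process with after(1000, ...). A generic, self-contained sketch of that deferred-retry pattern (not faceswap code; the simulated data source is hypothetical):

import tkinter as tk

data = []                                           # stands in for loss keys that appear later


def try_build(root):
    if not data:                                    # nothing logged yet: poll again in a second
        root.after(1000, try_build, root)
        return
    print("building graph with", data)              # real code would populate the notebook tabs here


root = tk.Tk()
root.after(0, try_build, root)
root.after(3500, lambda: data.append("loss_a"))     # simulate data arriving later
root.after(5000, root.destroy)
root.mainloop()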
@@ -59,7 +59,7 @@ class DisplayPage(ttk.Frame):  # pylint: disable=too-many-ancestors
     @staticmethod
     def set_vars():
         """ Override to return a dict of page specific variables """
-        return dict()
+        return {}

     def on_tab_select(self):  # pylint:disable=no-self-use
         """ Override for specific actions when the current tab is selected """

@@ -151,7 +151,7 @@ class DisplayPage(ttk.Frame):  # pylint: disable=too-many-ancestors

     def subnotebook_get_titles_ids(self):
         """ Return tabs ids and titles """
-        tabs = dict()
+        tabs = {}
         for tab_id in range(0, self.subnotebook.index("end")):
             tabs[self.subnotebook.tab(tab_id, "text")] = tab_id
         logger.debug(tabs)

@@ -213,11 +213,11 @@ class DisplayOptionalPage(DisplayPage):  # pylint: disable=too-many-ancestors
     def set_info_text(self):
         """ Set waiting for display text """
         if not self.vars["enabled"].get():
-            msg = "{} disabled".format(self.tabname.title())
+            msg = f"{self.tabname.title()} disabled"
         elif self.vars["enabled"].get() and not self.vars["ready"].get():
-            msg = "Waiting for {}...".format(self.tabname)
+            msg = f"Waiting for {self.tabname}..."
         else:
-            msg = "Displaying {}".format(self.tabname)
+            msg = f"Displaying {self.tabname}"
         logger.debug(msg)
         self.set_info(msg)

@@ -235,7 +235,7 @@ class DisplayOptionalPage(DisplayPage):  # pylint: disable=too-many-ancestors
                              command=self.save_items)
         btnsave.pack(padx=2, side=tk.RIGHT)
         Tooltip(btnsave,
-                text=_("Save {}(s) to file").format(self.tabname),
+                text=_(f"Save {self.tabname}(s) to file"),
                 wrap_length=200)

     def add_option_enable(self):

@@ -243,11 +243,11 @@ class DisplayOptionalPage(DisplayPage):  # pylint: disable=too-many-ancestors
         logger.debug("Adding enable option")
         chkenable = ttk.Checkbutton(self.optsframe,
                                     variable=self.vars["enabled"],
-                                    text="Enable {}".format(self.tabname),
+                                    text=f"Enable {self.tabname}",
                                     command=self.on_chkenable_change)
         chkenable.pack(side=tk.RIGHT, padx=5, anchor=tk.W)
         Tooltip(chkenable,
-                text=_("Enable or disable {} display").format(self.tabname),
+                text=_(f"Enable or disable {self.tabname} display"),
                 wrap_length=200)

     def save_items(self):
@@ -137,6 +137,7 @@ class FileHandler():  # pylint:disable=too-few-public-methods
                      "variable: %s)", self.__class__.__name__, handle_type, file_type, title,
                      initial_folder, initial_file, command, action, variable)
         self._handletype = handle_type
+        self._dummy_master = self._set_dummy_master()
         self._defaults = self._set_defaults()
         self._kwargs = self._set_kwargs(title,
                                         initial_folder,

@@ -145,7 +146,9 @@ class FileHandler():  # pylint:disable=too-few-public-methods
                                         command,
                                         action,
                                         variable)
-        self.return_file = getattr(self, "_{}".format(self._handletype.lower()))()
+        self.return_file = getattr(self, f"_{self._handletype.lower()}")()
+        self._remove_dummy_master()
+
         logger.debug("Initialized %s", self.__class__.__name__)

     @property

@@ -184,10 +187,10 @@ class FileHandler():  # pylint:disable=too-few-public-methods
             if platform.system() == "Linux":
                 filetypes[key] = [item
                                   if item[0] == "All files"
-                                  else (item[0], "{} {}".format(item[1], item[1].upper()))
+                                  else (item[0], f"{item[1]} {item[1].upper()}")
                                   for item in filetypes[key]]
             if len(filetypes[key]) > 2:
-                multi = ["{} Files".format(key.title())]
+                multi = [f"{key.title()} Files"]
                 multi.append(" ".join([ftype[1]
                                        for ftype in filetypes[key] if ftype[0] != "All files"]))
                 filetypes[key].insert(0, tuple(multi))

@@ -214,6 +217,35 @@ class FileHandler():  # pylint:disable=too-few-public-methods
                               "rotate": "save_filename",
                               "slice": "save_filename"}))

+    @classmethod
+    def _set_dummy_master(cls):
+        """ Add an option to force black font on Linux file dialogs KDE issue that displays light
+        font on white background).
+
+        This is a pretty hacky solution, but tkinter does not allow direct editing of file dialogs,
+        so we create a dummy frame and add the foreground option there, so that the file dialog can
+        inherit the foreground.
+
+        Returns
+        -------
+        tkinter.Frame or ``None``
+            The dummy master frame for Linux systems, otherwise ``None``
+        """
+        if platform.system().lower() == "linux":
+            retval = tk.Frame()
+            retval.option_add("*foreground", "black")
+        else:
+            retval = None
+        return retval
+
+    def _remove_dummy_master(self):
+        """ Destroy the dummy master widget on Linux systems. """
+        if platform.system().lower() != "linux":
+            return
+        self._dummy_master.destroy()
+        del self._dummy_master
+        self._dummy_master = None
+
     def _set_defaults(self):
         """ Set the default file type for the file dialog. Generally the first found file type
         will be used, but this is overridden if it is not appropriate.

@@ -264,7 +296,9 @@ class FileHandler():  # pylint:disable=too-few-public-methods
         logger.debug("Setting Kwargs: (title: %s, initial_folder: %s, initial_file: '%s', "
                      "file_type: '%s', command: '%s': action: '%s', variable: '%s')",
                      title, initial_folder, initial_file, file_type, command, action, variable)
-        kwargs = dict()
+
+        kwargs = dict(master=self._dummy_master)
+
         if self._handletype.lower() == "context":
             self._set_context_handletype(command, action, variable)

@@ -361,10 +395,10 @@ class Images():
         self._pathpreview = os.path.join(PATHCACHE, "preview")
         self._pathoutput = None
         self._previewoutput = None
-        self._previewtrain = dict()
+        self._previewtrain = {}
         self._previewcache = dict(modified=None,  # cache for extract and convert
                                   images=None,
-                                  filenames=list(),
+                                  filenames=[],
                                   placeholder=None)
         self._errcount = 0
         self._icons = self._load_icons()

@@ -420,7 +454,7 @@ class Images():
         """
         size = get_config().user_config_dict.get("icon_size", 16)
         size = int(round(size * get_config().scaling_factor))
-        icons = dict()
+        icons = {}
         pathicons = os.path.join(PATHCACHE, "icons")
         for fname in os.listdir(pathicons):
             name, ext = os.path.splitext(fname)

@@ -470,10 +504,10 @@ class Images():
         logger.debug("Clearing image cache")
         self._pathoutput = None
         self._previewoutput = None
-        self._previewtrain = dict()
+        self._previewtrain = {}
         self._previewcache = dict(modified=None,  # cache for extract and convert
                                   images=None,
-                                  filenames=list(),
+                                  filenames=[],
                                   placeholder=None)

     @staticmethod

@@ -600,10 +634,10 @@ class Images():
         logger.debug("num_images: %s", num_images)
         if num_images == 0:
             return False
-        samples = list()
+        samples = []
         start_idx = len(image_files) - num_images if len(image_files) > num_images else 0
         show_files = sorted(image_files, key=os.path.getctime)[start_idx:]
-        dropped_files = list()
+        dropped_files = []
         for fname in show_files:
             try:
                 img = Image.open(fname)

@@ -732,7 +766,7 @@ class Images():
         modified = None
         if not image_files:
             logger.debug("No preview to display")
-            self._previewtrain = dict()
+            self._previewtrain = {}
             return
         for img in image_files:
             modified = os.path.getmtime(img) if modified is None else modified

@@ -755,7 +789,7 @@ class Images():
                 self._errcount += 1
             else:
                 logger.error("Error reading the preview file for '%s'", img)
-                print("Error reading the preview file for {}".format(name))
+                print(f"Error reading the preview file for {name}")
                 self._previewtrain[name] = None

     def _get_current_size(self, name):

@@ -1126,7 +1160,7 @@ class Config():
             Additional text to be appended to the GUI title bar. Default: ``None``
         """
         title = "Faceswap.py"
-        title += " - {}".format(text) if text is not None and text else ""
+        title += f" - {text}" if text is not None and text else ""
         self.root.title(title)

     def set_geometry(self, width, height, fullscreen=False):

@@ -1154,8 +1188,7 @@ class Config():
         elif fullscreen:
             self.root.attributes('-zoomed', True)
         else:
-            self.root.geometry("{}x{}+80+80".format(str(initial_dimensions[0]),
-                                                    str(initial_dimensions[1])))
+            self.root.geometry(f"{str(initial_dimensions[0])}x{str(initial_dimensions[1])}+80+80")
         logger.debug("Geometry: %sx%s", *initial_dimensions)

@@ -1260,7 +1293,7 @@ class PreviewTrigger():
         """
         trigger = self._trigger_files[trigger_type]
         if not os.path.isfile(trigger):
-            with open(trigger, "w"):
+            with open(trigger, "w", encoding="utf8"):
                 pass
             logger.debug("Set preview trigger: %s", trigger)
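The _set_dummy_master docstring above explains the Linux (KDE) light-text-on-white file dialog fix: a throwaway widget carries a foreground option that the dialog inherits via its master. A standalone sketch of the same trick with plain tkinter (illustrative only; assumes a Linux desktop session):

import platform
import tkinter as tk
from tkinter import filedialog

root = tk.Tk()
root.withdraw()                                    # no main window needed for this demo

master = None
if platform.system().lower() == "linux":
    master = tk.Frame()                            # dummy widget purely to carry the option
    master.option_add("*foreground", "black")      # dialog inherits a readable foreground

filename = filedialog.askopenfilename(master=master, title="Pick a file")
print(filename)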
@@ -9,9 +9,8 @@ import numpy as np
 import tensorflow as tf
 from keras import backend as K
 from keras import initializers
-from keras.utils import get_custom_objects

-from lib.utils import get_backend
+from lib.utils import get_backend, get_keras_custom_objects as get_custom_objects

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -64,7 +63,7 @@ def compute_fans(shape, data_format='channels_last'):
     return fan_in, fan_out


-class ICNR(initializers.Initializer):  # pylint: disable=invalid-name
+class ICNR(initializers.Initializer):  # pylint: disable=invalid-name,no-member
     """ ICNR initializer for checkerboard artifact free sub pixel convolution

     Parameters

@@ -167,11 +166,11 @@ class ICNR(initializers.Initializer):  # pylint: disable=invalid-name
         config = {"scale": self.scale,
                   "initializer": self.initializer
                   }
-        base_config = super(ICNR, self).get_config()
+        base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))


-class ConvolutionAware(initializers.Initializer):
+class ConvolutionAware(initializers.Initializer):  # pylint: disable=no-member
     """
     Initializer that generates orthogonal convolution filters in the Fourier space. If this
     initializer is passed a shape that is not 3D or 4D, orthogonal initialization will be used.

@@ -204,8 +203,8 @@ class ConvolutionAware(initializers.Initializer):
     def __init__(self, eps_std=0.05, seed=None, initialized=False):
         self.eps_std = eps_std
         self.seed = seed
-        self.orthogonal = initializers.Orthogonal()
-        self.he_uniform = initializers.he_uniform()
+        self.orthogonal = initializers.Orthogonal()  # pylint:disable=no-member
+        self.he_uniform = initializers.he_uniform()  # pylint:disable=no-member
         self.initialized = initialized

     def __call__(self, shape, dtype=None):
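Several modules in this commit (initializers, layers, normalization, optimizers) stop importing get_custom_objects from keras.utils and instead pull a get_keras_custom_objects helper from lib.utils, because the function's home moved between Keras/TF releases (the "update keras get_custom_objects for tf>2.6" item). The helper itself is not shown in this diff; a plausible, strongly hedged sketch of the idea:

def get_keras_custom_objects():
    """ Sketch only: return the Keras custom-object registry wherever the installed
    Keras / tf.keras combination keeps it. The real lib.utils implementation and its
    exact branch points are not part of this diff. """
    try:  # standalone keras (e.g. AMD/plaidml backend, older TF releases)
        from keras.utils import get_custom_objects  # pylint:disable=import-outside-toplevel
    except ImportError:  # tf.keras-only environments
        from tensorflow.keras.utils import get_custom_objects  # noqa pylint:disable=import-outside-toplevel,no-name-in-module
    return get_custom_objects()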
@@ -10,16 +10,15 @@ import tensorflow as tf
 import keras.backend as K

 from keras.layers import InputSpec, Layer
-from keras.utils import get_custom_objects

-from lib.utils import get_backend
+from lib.utils import get_backend, get_keras_custom_objects as get_custom_objects

 if get_backend() == "amd":
     from lib.plaidml_utils import pad
     from keras.utils import conv_utils  # pylint:disable=ungrouped-imports
 else:
     from tensorflow import pad
-    from tensorflow.python.keras.utils import conv_utils
+    from tensorflow.python.keras.utils import conv_utils  # pylint:disable=no-name-in-module


 class PixelShuffler(Layer):

@@ -64,20 +63,22 @@ class PixelShuffler(Layer):
     def __init__(self, size=(2, 2), data_format=None, **kwargs):
         super().__init__(**kwargs)
         if get_backend() == "amd":
-            self.data_format = K.normalize_data_format(data_format)
+            self.data_format = K.normalize_data_format(data_format)  # pylint:disable=no-member
         else:
             self.data_format = conv_utils.normalize_data_format(data_format)
         self.size = conv_utils.normalize_tuple(size, 2, 'size')

-    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
+    def call(self, inputs, *args, **kwargs):
         """This is where the layer's logic lives.

         Parameters
         ----------
         inputs: tensor
             Input tensor, or list/tuple of input tensors
+        args: tuple
+            Additional standard keras Layer arguments
         kwargs: dict
-            Additional keyword arguments. Unused
+            Additional standard keras Layer keyword arguments

         Returns
         -------

@@ -186,7 +187,7 @@ class PixelShuffler(Layer):
         """
         config = {'size': self.size,
                   'data_format': self.data_format}
-        base_config = super(PixelShuffler, self).get_config()
+        base_config = super().get_config()

         return dict(list(base_config.items()) + list(config.items()))

@@ -208,15 +209,17 @@ class KResizeImages(Layer):
         self.size = size
         self.interpolation = interpolation

-    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
+    def call(self, inputs, *args, **kwargs):
         """ Call the upsample layer

         Parameters
         ----------
         inputs: tensor
             Input tensor, or list/tuple of input tensors
+        args: tuple
+            Additional standard keras Layer arguments
         kwargs: dict
-            Additional keyword arguments. Unused
+            Additional standard keras Layer keyword arguments

         Returns
         -------

@@ -316,11 +319,11 @@ class SubPixelUpscaling(Layer):
     """

     def __init__(self, scale_factor=2, data_format=None, **kwargs):
-        super(SubPixelUpscaling, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         self.scale_factor = scale_factor
         if get_backend() == "amd":
-            self.data_format = K.normalize_data_format(data_format)
+            self.data_format = K.normalize_data_format(data_format)  # pylint:disable=no-member
         else:
             self.data_format = conv_utils.normalize_data_format(data_format)

@@ -337,15 +340,17 @@ class SubPixelUpscaling(Layer):
         """
         pass  # pylint: disable=unnecessary-pass

-    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
+    def call(self, inputs, *args, **kwargs):
         """This is where the layer's logic lives.

         Parameters
         ----------
         inputs: tensor
             Input tensor, or list/tuple of input tensors
+        args: tuple
+            Additional standard keras Layer arguments
         kwargs: dict
-            Additional keyword arguments. Unused
+            Additional standard keras Layer keyword arguments

         Returns
         -------

@@ -460,7 +465,7 @@ class SubPixelUpscaling(Layer):
         """
         config = {"scale_factor": self.scale_factor,
                   "data_format": self.data_format}
-        base_config = super(SubPixelUpscaling, self).get_config()
+        base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))

@@ -592,7 +597,7 @@ class ReflectionPadding2D(Layer):
         """
         config = {'stride': self.stride,
                   'kernel_size': self.kernel_size}
-        base_config = super(ReflectionPadding2D, self).get_config()
+        base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))

@@ -602,9 +607,9 @@ class _GlobalPooling2D(Layer):
     From keras as access to pooling is trickier in tensorflow.keras
     """
     def __init__(self, data_format=None, **kwargs):
-        super(_GlobalPooling2D, self).__init__(**kwargs)
+        super().__init__(**kwargs)
         if get_backend() == "amd":
-            self.data_format = K.normalize_data_format(data_format)
+            self.data_format = K.normalize_data_format(data_format)  # pylint:disable=no-member
         else:
             self.data_format = conv_utils.normalize_data_format(data_format)
         self.input_spec = InputSpec(ndim=4)

@@ -621,37 +626,41 @@ class _GlobalPooling2D(Layer):
             return (input_shape[0], input_shape[3])
         return (input_shape[0], input_shape[1])

-    def call(self, inputs, **kwargs):
+    def call(self, inputs, *args, **kwargs):
         """ Override to call the layer.

         Parameters
         ----------
         inputs: Tensor
             The input to the layer
+        args: tuple
+            Additional standard keras Layer arguments
         kwargs: dict
-            Additional keyword arguments
+            Additional standard keras Layer keyword arguments
         """
         raise NotImplementedError

     def get_config(self):
         """ Set the Keras config """
         config = {'data_format': self.data_format}
-        base_config = super(_GlobalPooling2D, self).get_config()
+        base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))


 class GlobalMinPooling2D(_GlobalPooling2D):
     """Global minimum pooling operation for spatial data. """

-    def call(self, inputs, **kwargs):
+    def call(self, inputs, *args, **kwargs):
         """This is where the layer's logic lives.

         Parameters
         ----------
         inputs: tensor
             Input tensor, or list/tuple of input tensors
+        args: tuple
+            Additional standard keras Layer arguments
         kwargs: dict
-            Additional keyword arguments
+            Additional standard keras Layer keyword arguments

         Returns
         -------

@@ -668,15 +677,17 @@ class GlobalMinPooling2D(_GlobalPooling2D):
 class GlobalStdDevPooling2D(_GlobalPooling2D):
     """Global standard deviation pooling operation for spatial data. """

-    def call(self, inputs, **kwargs):
+    def call(self, inputs, *args, **kwargs):
         """This is where the layer's logic lives.

         Parameters
         ----------
         inputs: tensor
             Input tensor, or list/tuple of input tensors
+        args: tuple
+            Additional standard keras Layer arguments
         kwargs: dict
-            Additional keyword arguments
+            Additional standard keras Layer keyword arguments

         Returns
         -------

@@ -702,7 +713,7 @@ class L2_normalize(Layer):  # pylint:disable=invalid-name
     """
     def __init__(self, axis, **kwargs):
         self.axis = axis
-        super(L2_normalize, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def call(self, inputs):  # pylint:disable=arguments-differ
         """This is where the layer's logic lives.

@@ -736,7 +747,7 @@ class L2_normalize(Layer):  # pylint:disable=invalid-name
         dict
             A python dictionary containing the layer configuration
         """
-        config = super(L2_normalize, self).get_config()
+        config = super().get_config()
         config["axis"] = self.axis
         return config
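The repeated change above is widening the custom layers' call signature from call(self, inputs, **kwargs) to call(self, inputs, *args, **kwargs) so it lines up with the base keras.layers.Layer.call that newer TF releases pass extra positional arguments to. A minimal custom layer written against that signature, registered the same way these classes are (illustrative only, not faceswap code):

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import get_custom_objects


class Scale(Layer):
    """ Multiply the input by a fixed factor. """

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs, *args, **kwargs):  # matches keras.layers.Layer.call
        return inputs * self.factor

    def get_config(self):
        config = super().get_config()
        config["factor"] = self.factor
        return config


# Registering makes the layer loadable by name from saved model configs
get_custom_objects().update({"Scale": Scale})
print(Scale(3.0)(tf.ones((2, 2))))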
@@ -7,17 +7,17 @@ import logging

 import numpy as np
 import tensorflow as tf
-from tensorflow.python.keras.engine import compile_utils
+from tensorflow.python.keras.engine import compile_utils  # pylint:disable=no-name-in-module

 from keras import backend as K

 logger = logging.getLogger(__name__)  # pylint:disable=invalid-name


-class DSSIMObjective(tf.keras.losses.Loss):
+class DSSIMObjective(tf.keras.losses.Loss):  # pylint:disable=too-few-public-methods
     """ DSSIM Loss Function

-    Difference of Structural Similarity (DSSIM loss function). Clipped between 0 and 0.5
+    Difference of Structural Similarity (DSSIM loss function).

     Parameters
     ----------

@@ -25,66 +25,24 @@ class DSSIMObjective(tf.keras.losses.Loss):
         Parameter of the SSIM. Default: `0.01`
     k_2: float, optional
         Parameter of the SSIM. Default: `0.03`
-    kernel_size: int, optional
-        Size of the sliding window Default: `3`
+    filter_size: int, optional
+        size of gaussian filter Default: `11`
+    filter_sigma: float, optional
+        Width of gaussian filter Default: `1.5`
     max_value: float, optional
         Max value of the output. Default: `1.0`

-    Notes
-    ------
-    You should add a regularization term like a l2 loss in addition to this one.
-
-    References
-    ----------
-    https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
-
-    MIT License
-
-    Copyright (c) 2017 Fariz Rahman
-
-    Permission is hereby granted, free of charge, to any person obtaining a copy
-    of this software and associated documentation files (the "Software"), to deal
-    in the Software without restriction, including without limitation the rights
-    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-    copies of the Software, and to permit persons to whom the Software is
-    furnished to do so, subject to the following conditions:
-
-    The above copyright notice and this permission notice shall be included in all
-    copies or substantial portions of the Software.
-
-    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-    SOFTWARE.
     """
-    def __init__(self, k_1=0.01, k_2=0.03, kernel_size=3, max_value=1.0):
+    def __init__(self, k_1=0.01, k_2=0.03, filter_size=11, filter_sigma=1.5, max_value=1.0):
         super().__init__(name="DSSIMObjective")
-        self.kernel_size = kernel_size
+        self.filter_size = filter_size
+        self.filter_sigma = filter_sigma
         self.k_1 = k_1
         self.k_2 = k_2
         self.max_value = max_value
-        self.c_1 = (self.k_1 * self.max_value) ** 2
-        self.c_2 = (self.k_2 * self.max_value) ** 2
-        self.dim_ordering = K.image_data_format()
-
-    @staticmethod
-    def _int_shape(input_tensor):
-        """ Returns the shape of tensor or variable as a tuple of int or None entries.
-
-        Parameters
-        ----------
-        input_tensor: tensor or variable
-            The input to return the shape for
-
-        Returns
-        -------
-        tuple
-            A tuple of integers (or None entries)
-        """
-        return K.int_shape(input_tensor)

     def call(self, y_true, y_pred):
         """ Call the DSSIM Loss Function.

@@ -100,104 +58,19 @@ class DSSIMObjective(tf.keras.losses.Loss):
         -------
         tensor
             The DSSIM Loss value
-
-        Notes
-        -----
-        There are additional parameters for this function. some of the 'modes' for edge behavior
-        do not yet have a gradient definition in the Theano tree and cannot be used for learning
         """
-
-        kernel = [self.kernel_size, self.kernel_size]
-        y_true = K.reshape(y_true, [-1] + list(self._int_shape(y_pred)[1:]))
-        y_pred = K.reshape(y_pred, [-1] + list(self._int_shape(y_pred)[1:]))
-        patches_pred = self.extract_image_patches(y_pred,
-                                                  kernel,
-                                                  kernel,
-                                                  'valid',
-                                                  self.dim_ordering)
-        patches_true = self.extract_image_patches(y_true,
-                                                  kernel,
-                                                  kernel,
-                                                  'valid',
-                                                  self.dim_ordering)
-
-        # Get mean
-        u_true = K.mean(patches_true, axis=-1)
-        u_pred = K.mean(patches_pred, axis=-1)
-        # Get variance
-        var_true = K.var(patches_true, axis=-1)
-        var_pred = K.var(patches_pred, axis=-1)
-        # Get standard deviation
-        covar_true_pred = K.mean(
-            patches_true * patches_pred, axis=-1) - u_true * u_pred
-
-        ssim = (2 * u_true * u_pred + self.c_1) * (
-            2 * covar_true_pred + self.c_2)
-        denom = (K.square(u_true) + K.square(u_pred) + self.c_1) * (
-            var_pred + var_true + self.c_2)
-        ssim /= denom  # no need for clipping, c_1 + c_2 make the denorm non-zero
-        return (1.0 - ssim) / 2.0
-
-    @staticmethod
-    def _preprocess_padding(padding):
-        """Convert keras padding to tensorflow padding.
-
-        Parameters
-        ----------
-        padding: string,
-            `"same"` or `"valid"`.
-
-        Returns
-        -------
-        str
-            `"SAME"` or `"VALID"`.
-
-        Raises
-        ------
-        ValueError
-            If `padding` is invalid.
-        """
-        if padding == 'same':
-            padding = 'SAME'
-        elif padding == 'valid':
-            padding = 'VALID'
-        else:
-            raise ValueError('Invalid padding:', padding)
-        return padding
-
-    def extract_image_patches(self, input_tensor, k_sizes, s_sizes,
-                              padding='same', data_format='channels_last'):
-        """ Extract the patches from an image.
-
-        Parameters
-        ----------
-        input_tensor: tensor
-            The input image
-        k_sizes: tuple
-            2-d tuple with the kernel size
-        s_sizes: tuple
-            2-d tuple with the strides size
-        padding: str, optional
-            `"same"` or `"valid"`. Default: `"same"`
-        data_format: str, optional.
-            `"channels_last"` or `"channels_first"`. Default: `"channels_last"`
-
-        Returns
-        -------
-        The (k_w, k_h) patches extracted
-            Tensorflow ==> (batch_size, w, h, k_w, k_h, c)
-            Theano ==> (batch_size, w, h, c, k_w, k_h)
-        """
-        kernel = [1, k_sizes[0], k_sizes[1], 1]
-        strides = [1, s_sizes[0], s_sizes[1], 1]
-        padding = self._preprocess_padding(padding)
-        if data_format == 'channels_first':
-            input_tensor = K.permute_dimensions(input_tensor, (0, 2, 3, 1))
-        patches = tf.image.extract_patches(input_tensor, kernel, strides, [1, 1, 1, 1], padding)
-        return patches
+        ssim = tf.image.ssim(y_true,
+                             y_pred,
+                             self.max_value,
+                             filter_size=self.filter_size,
+                             filter_sigma=self.filter_sigma,
+                             k1=self.k_1,
+                             k2=self.k_2)
+        dssim_loss = 1. - ssim
+        return dssim_loss


-class GeneralizedLoss(tf.keras.losses.Loss):
+class GeneralizedLoss(tf.keras.losses.Loss):  # pylint:disable=too-few-public-methods
     """ Generalized function used to return a large variety of mathematical loss functions.

     The primary benefit is a smooth, differentiable version of L1 loss.

@@ -247,10 +120,11 @@ class GeneralizedLoss(tf.keras.losses.Loss):
         return loss


-class LInfNorm(tf.keras.losses.Loss):
+class LInfNorm(tf.keras.losses.Loss):  # pylint:disable=too-few-public-methods
     """ Calculate the L-inf norm as a loss function. """

-    def call(self, y_true, y_pred):
+    @classmethod
+    def call(cls, y_true, y_pred):
         """ Call the L-inf norm loss function.

         Parameters

@@ -271,7 +145,7 @@ class LInfNorm(tf.keras.losses.Loss):
         return loss


-class GradientLoss(tf.keras.losses.Loss):
+class GradientLoss(tf.keras.losses.Loss):  # pylint:disable=too-few-public-methods
     """ Gradient Loss Function.

     Calculates the first and second order gradient difference between pixels of an image in the x

@@ -392,7 +266,7 @@ class GradientLoss(tf.keras.losses.Loss):
         return (xy_out1 - xy_out2) * 0.25


-class GMSDLoss(tf.keras.losses.Loss):
+class GMSDLoss(tf.keras.losses.Loss):  # pylint:disable=too-few-public-methods
     """ Gradient Magnitude Similarity Deviation Loss.

     Improved image quality metric over MS-SSIM with easier calculations

@@ -486,7 +360,9 @@ class GMSDLoss(tf.keras.losses.Loss):
         # Use depth-wise convolution to calculate edge maps per channel.
         # Output tensor has shape [batch_size, h, w, d * num_kernels].
         pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
-        padded = tf.pad(image, pad_sizes, mode='REFLECT')
+        padded = tf.pad(image,  # pylint:disable=unexpected-keyword-arg,no-value-for-parameter
+                        pad_sizes,
+                        mode='REFLECT')
         output = K.depthwise_conv2d(padded, kernels)

         if not magnitude:  # direction of edges
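The DSSIM class above drops the hand-rolled keras-contrib patch extraction in favour of the stock tf.image.ssim ("dssim loss - Move to stock tf.ssim function"). A standalone check of the same calculation on random data (illustrative only):

import tensorflow as tf

y_true = tf.random.uniform((4, 64, 64, 3))            # batch of images in [0, 1]
y_pred = tf.clip_by_value(y_true + tf.random.normal(y_true.shape, stddev=0.05), 0., 1.)

ssim = tf.image.ssim(y_true, y_pred, max_val=1.0,
                     filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
dssim = 1. - ssim                                     # per-image DSSIM loss, shape (4,)
print(tf.reduce_mean(dssim))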
@@ -7,15 +7,16 @@ import inspect
 from keras.layers import Layer, InputSpec
 from keras import initializers, regularizers, constraints
 from keras import backend as K
-from keras.utils import get_custom_objects

-from lib.utils import get_backend
+from lib.utils import get_backend, get_keras_custom_objects as get_custom_objects

 if get_backend() == "amd":
-    from keras.backend import normalize_data_format  # pylint:disable=ungrouped-imports
+    from keras.backend \
+        import normalize_data_format  # pylint:disable=ungrouped-imports,no-name-in-module
 else:
-    from tensorflow.python.keras.utils.conv_utils import normalize_data_format
-    # pylint:disable=no-name-in-module
+    from tensorflow.python.keras.utils.conv_utils \
+        import normalize_data_format  # pylint:disable=no-name-in-module


 class InstanceNormalization(Layer):

@@ -61,6 +62,7 @@ class InstanceNormalization(Layer):
         - Instance Normalization: The Missing Ingredient for Fast Stylization - \
           https://arxiv.org/abs/1607.08022
     """
+    # pylint:disable=too-many-instance-attributes,too-many-arguments
     def __init__(self,
                  axis=None,
                  epsilon=1e-3,

@@ -348,6 +350,7 @@ class GroupNormalization(Layer):
     ----------
     Shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN
     """
+    # pylint:disable=too-many-instance-attributes
     def __init__(self, axis=-1, gamma_init='one', beta_init='zero', gamma_regularizer=None,
                  beta_regularizer=None, epsilon=1e-6, group=32, data_format=None, **kwargs):
         self.beta = None
@ -4,10 +4,14 @@ import inspect
import sys

import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.backend as K  # pylint:disable=no-name-in-module,import-error
# tf.keras has a LayerNormaliztion implementation
from tensorflow.keras.layers import Layer, LayerNormalization  # noqa pylint:disable=unused-import
from tensorflow.keras.utils import get_custom_objects
# pylint:disable=unused-import
from tensorflow.keras.layers import (  # noqa pylint:disable=no-name-in-module,import-error
Layer,
LayerNormalization)

from lib.utils import get_keras_custom_objects as get_custom_objects


class RMSNormalization(Layer):

@ -117,9 +121,10 @@ class RMSNormalization(Layer):
mean_square = K.mean(K.square(inputs), axis=self.axis, keepdims=True)
else:
partial_size = int(layer_size * self.partial)
partial_x, _ = tf.split(inputs,
[partial_size, layer_size - partial_size],
axis=self.axis)
partial_x, _ = tf.split(  # pylint:disable=redundant-keyword-arg,no-value-for-parameter
inputs,
[partial_size, layer_size - partial_size],
axis=self.axis)
mean_square = K.mean(K.square(partial_x), axis=self.axis, keepdims=True)

recip_square_root = tf.math.rsqrt(mean_square + self.epsilon)
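For reference, the partial statistic computed above can be reproduced in isolation. A minimal sketch, assuming a (batch, features) tensor and a partial ratio of 0.5 (the shapes and epsilon are illustrative, not the layer's defaults):

import tensorflow as tf
import tensorflow.keras.backend as K

inputs = tf.random.uniform((2, 8))     # (batch, layer_size)
layer_size, partial, axis, epsilon = 8, 0.5, -1, 1e-8

partial_size = int(layer_size * partial)
partial_x, _ = tf.split(inputs, [partial_size, layer_size - partial_size], axis=axis)
mean_square = K.mean(K.square(partial_x), axis=axis, keepdims=True)
normed = inputs * tf.math.rsqrt(mean_square + epsilon)   # RMS-normalised activations
print(normed.shape)                                      # (2, 8)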
@ -4,7 +4,7 @@ import inspect
import sys

from keras import backend as K
from keras.optimizers import Optimizer
from keras.optimizers import Optimizer, Adam, Nadam, RMSprop  # noqa pylint:disable=unused-import
from keras.utils import get_custom_objects
@ -8,7 +8,10 @@ import inspect
import sys

import tensorflow as tf
from keras.utils import get_custom_objects
from tensorflow.keras.optimizers import (  # noqa pylint:disable=no-name-in-module,unused-import,import-error
Adam, Nadam, RMSprop)

from lib.utils import get_keras_custom_objects as get_custom_objects


class AdaBelief(tf.keras.optimizers.Optimizer):

@ -128,6 +131,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-14,
weight_decay=0.0, rectify=True, amsgrad=False, sma_threshold=5.0, total_steps=0,
warmup_proportion=0.1, min_lr=0.0, name="AdaBeliefOptimizer", **kwargs):
# pylint:disable=too-many-arguments
super().__init__(name, **kwargs)
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
self._set_hyper("beta_1", beta_1)

@ -196,7 +200,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
return wd_t

def _resource_apply_dense(self, grad, handle, apply_state=None):
# pylint:disable=too-many-locals
# pylint:disable=too-many-locals,unused-argument
""" Add ops to apply dense gradients to the variable handle.

Parameters

@ -274,7 +278,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
return tf.group(*updates)

def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
# pylint:disable=too-many-locals
# pylint:disable=too-many-locals, unused-argument
""" Add ops to apply sparse gradients to the variable handle.

Similar to _apply_sparse, the indices argument to this method has been de-duplicated.

@ -324,7 +328,7 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
m_corr_t = m_t / (1.0 - beta_1_power)

var_v = self.get_slot(handle, "v")
m_t_indices = tf.gather(m_t, indices)
m_t_indices = tf.gather(m_t, indices)  # pylint:disable=no-value-for-parameter
v_scaled_g_values = tf.math.square(grad - m_t_indices) * (1 - beta_2_t)
v_t = var_v.assign(var_v * beta_2_t + epsilon_t, use_locking=self._use_locking)
v_t = self._resource_scatter_add(var_v, indices, v_scaled_g_values)

@ -355,7 +359,9 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
var_update = self._resource_scatter_add(handle,
indices,
tf.gather(tf.math.negative(lr_t) * var_t, indices))
tf.gather(  # pylint:disable=no-value-for-parameter
tf.math.negative(lr_t) * var_t,
indices))

updates = [var_update, m_t, v_t]
if self.amsgrad:

@ -391,6 +397,6 @@ class AdaBelief(tf.keras.optimizers.Optimizer):
# Update layers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
for _name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and obj.__module__ == __name__:
get_custom_objects().update({name: obj})
get_custom_objects().update({_name: obj})
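The loop above is what lets the custom optimizer be found by name when a saved model is reloaded. A hedged, standalone illustration of that mechanism with a made-up layer ("DoubleIt" is purely illustrative and not part of faceswap):

import tensorflow as tf
from tensorflow.keras.utils import get_custom_objects

class DoubleIt(tf.keras.layers.Layer):
    """ Toy layer that multiplies its input by two. """
    def call(self, inputs):
        return inputs * 2.0

# Register the class by name, exactly as the module does for its own members
get_custom_objects().update({"DoubleIt": DoubleIt})

model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(4,)), DoubleIt()])
model.save("toy.h5")
# Reloading works because "DoubleIt" resolves through the custom objects registry
reloaded = tf.keras.models.load_model("toy.h5")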
lib/utils.py
@ -22,6 +22,7 @@ _image_extensions = [  # pylint:disable=invalid-name
_video_extensions = [  # pylint:disable=invalid-name
".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
".ts", ".vob"]
_TF_VERS = None


class _Backend():  # pylint:disable=too-few-public-methods

@ -60,8 +61,7 @@ class _Backend():  # pylint:disable=too-few-public-methods
# Check if environment variable is set, if so use that
if "FACESWAP_BACKEND" in os.environ:
fs_backend = os.environ["FACESWAP_BACKEND"].lower()
print("Setting Faceswap backend from environment variable to "
"{}".format(fs_backend.upper()))
print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
return fs_backend
# Intercept for sphinx docs build
if sys.argv[0].endswith("sphinx-build"):

@ -70,7 +70,7 @@ class _Backend():  # pylint:disable=too-few-public-methods
self._configure_backend()
while True:
try:
with open(self._config_file, "r") as cnf:
with open(self._config_file, "r", encoding="utf8") as cnf:
config = json.load(cnf)
break
except json.decoder.JSONDecodeError:

@ -80,7 +80,7 @@ class _Backend():  # pylint:disable=too-few-public-methods
if fs_backend is None or fs_backend.lower() not in self._backends.values():
fs_backend = self._configure_backend()
if current_process().name == "MainProcess":
print("Setting Faceswap backend to {}".format(fs_backend.upper()))
print(f"Setting Faceswap backend to {fs_backend.upper()}")
return fs_backend.lower()

def _configure_backend(self):

@ -95,14 +95,14 @@ class _Backend():  # pylint:disable=too-few-public-methods
while True:
selection = input("1: AMD, 2: CPU, 3: NVIDIA: ")
if selection not in ("1", "2", "3"):
print("'{}' is not a valid selection. Please try again".format(selection))
print(f"'{selection}' is not a valid selection. Please try again")
continue
break
fs_backend = self._backends[selection].lower()
config = {"backend": fs_backend}
with open(self._config_file, "w") as cnf:
with open(self._config_file, "w", encoding="utf8") as cnf:
json.dump(config, cnf)
print("Faceswap config written to: {}".format(self._config_file))
print(f"Faceswap config written to: {self._config_file}")
return fs_backend

@ -132,6 +132,32 @@ def set_backend(backend):
_FS_BACKEND = backend.lower()


def get_tf_version():
""" Obtain the major.minor version of currently installed Tensorflow.

Returns
-------
float
The currently installed tensorflow version
"""
global _TF_VERS  # pylint:disable=global-statement
if _TF_VERS is None:
import tensorflow as tf  # pylint:disable=import-outside-toplevel
_TF_VERS = float(".".join(tf.__version__.split(".")[:2]))  # pylint:disable=no-member
return _TF_VERS


def get_keras_custom_objects():
""" Wrapper to obtain keras.utils.get_custom_objects from correct location depending on
backend used and tensorflow version. """
# pylint:disable=no-name-in-module,import-outside-toplevel
if get_backend() == "amd" or get_tf_version() < 2.8:
from keras.utils import get_custom_objects
else:
from keras.utils.generic_utils import get_custom_objects
return get_custom_objects()


def get_folder(path, make_folder=True):
""" Return a path to a folder, creating it if it doesn't exist
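A short usage sketch for the two new helpers (assuming a TF backend is active); the `MyLayer` placeholder below is hypothetical and only marks where a real custom class would go:

from lib.utils import get_tf_version, get_keras_custom_objects

print(get_tf_version())          # e.g. 2.8 - cached in _TF_VERS after the first call

class MyLayer:                   # placeholder for a real custom layer/loss/optimizer
    pass

# Same registry keras uses to resolve custom classes when loading models,
# imported from the correct location for the backend / TF version in use
get_keras_custom_objects().update({"MyLayer": MyLayer})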
@ -176,7 +202,7 @@ def get_image_paths(directory, extension=None):
"""
logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
image_extensions = _image_extensions if extension is None else [extension]
dir_contents = list()
dir_contents = []

if not os.path.exists(directory):
logger.debug("Creating folder: '%s'", directory)

@ -242,7 +268,7 @@ def full_path_split(path):
>>> ["foo", "baz", "bar"]
"""
logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
allparts = list()
allparts = []
while True:
parts = os.path.split(path)
if parts[0] == path:  # sentinel for absolute paths

@ -297,9 +323,9 @@ def deprecation_warning(function, additional_info=None):
"""
logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
logger.debug("func_name: %s, additional_info: %s", function, additional_info)
msg = "{} has been deprecated and will be removed from a future update.".format(function)
msg = f"{function} has been deprecated and will be removed from a future update."
if additional_info is not None:
msg += " {}".format(additional_info)
msg += f" {additional_info}"
logger.warning(msg)

@ -355,7 +381,7 @@ class FaceswapError(Exception):
pass  # pylint:disable=unnecessary-pass


class GetModel():  # Pylint:disable=too-few-public-methods
class GetModel():  # pylint:disable=too-few-public-methods
""" Check for models in their cache path.

If available, return the path, if not available, get, unzip and install model

@ -428,7 +454,7 @@ class GetModel():  # Pylint:disable=too-few-public-methods
@property
def _model_zip_path(self):
""" str: The full path to downloaded zip file. """
retval = os.path.join(self._cache_dir, "{}.zip".format(self._model_full_name))
retval = os.path.join(self._cache_dir, f"{self._model_full_name}.zip")
self.logger.trace(retval)
return retval

@ -462,8 +488,8 @@ class GetModel():  # Pylint:disable=too-few-public-methods
@property
def _url_download(self):
""" strL Base download URL for models. """
tag = "v{}.{}.{}".format(self._url_section, self._git_model_id, self._model_version)
retval = "{}/{}/{}.zip".format(self._url_base, tag, self._model_full_name)
tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}"
retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
self.logger.trace("Download url: %s", retval)
return retval

@ -493,11 +519,11 @@ class GetModel():  # Pylint:disable=too-few-public-methods
downloaded_size = self._url_partial_size
req = urllib.request.Request(self._url_download)
if downloaded_size != 0:
req.add_header("Range", "bytes={}-".format(downloaded_size))
response = urllib.request.urlopen(req, timeout=10)
self.logger.debug("header info: {%s}", response.info())
self.logger.debug("Return Code: %s", response.getcode())
self._write_zipfile(response, downloaded_size)
req.add_header("Range", f"bytes={downloaded_size}-")
with urllib.request.urlopen(req, timeout=10) as response:
self.logger.debug("header info: {%s}", response.info())
self.logger.debug("Return Code: %s", response.getcode())
self._write_zipfile(response, downloaded_size)
break
except (socket_error, socket_timeout,
urllib.error.HTTPError, urllib.error.URLError) as err:

@ -548,8 +574,8 @@ class GetModel():  # Pylint:disable=too-few-public-methods
""" Unzip the model file to the cache folder """
self.logger.info("Extracting: '%s'", self._model_name)
try:
zip_file = zipfile.ZipFile(self._model_zip_path, "r")
self._write_model(zip_file)
with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
self._write_model(zip_file)
except Exception as err:  # pylint:disable=broad-except
self.logger.error("Unable to extract model file: %s", str(err))
sys.exit(1)
@ -20,13 +20,12 @@ from keras import losses as k_losses
from keras import backend as K
from keras.layers import Input
from keras.models import load_model, Model as KModel
from keras.optimizers import Adam, Nadam, RMSprop

from lib.serializer import get_serializer
from lib.model.backup_restore import Backup
from lib.model import losses, optimizers
from lib.model.nn_blocks import set_config as set_nnblock_config
from lib.utils import get_backend, FaceswapError
from lib.utils import get_backend, get_tf_version, FaceswapError
from plugins.train._config import Config

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@ -264,13 +263,14 @@ class ModelBase():
return

if len(multiple_models) == 1:
msg = ("You have requested to train with the '{}' plugin, but a model file for the "
"'{}' plugin already exists in the folder '{}'.\nPlease select a different "
"model folder.".format(self.name, multiple_models[0], self.model_dir))
msg = (f"You have requested to train with the '{self.name}' plugin, but a model file "
f"for the '{multiple_models[0]}' plugin already exists in the folder "
f"'{self.model_dir}'.\nPlease select a different model folder.")
else:
msg = ("There are multiple plugin types ('{}') stored in the model folder '{}'. This "
"is not supported.\nPlease split the model files into their own folders before "
"proceeding".format("', '".join(multiple_models), self.model_dir))
ptypes = "', '".join(multiple_models)
msg = (f"There are multiple plugin types ('{ptypes}') stored in the model folder '"
f"{self.model_dir}'. This is not supported.\nPlease split the model files into "
"their own folders before proceeding")
raise FaceswapError(msg)

def build(self):
@ -311,11 +311,11 @@ class ModelBase():
if not all(os.path.isfile(os.path.join(self.model_dir, fname))
for fname in self._legacy_mapping()):
return
archive_dir = "{}_TF1_Archived".format(self.model_dir)
archive_dir = f"{self.model_dir}_TF1_Archived"
if os.path.exists(archive_dir):
raise FaceswapError("We need to update your model files for use with Tensorflow 2.x, "
"but the archive folder already exists. Please remove the "
"following folder to continue: '{}'".format(archive_dir))
f"following folder to continue: '{archive_dir}'")

logger.info("Updating legacy models for Tensorflow 2.x")
logger.info("Your Tensorflow 1.x models will be archived in the following location: '%s'",

@ -371,7 +371,7 @@ class ModelBase():
input_shapes = [self.input_shape, self.input_shape]
else:
input_shapes = self.input_shape
inputs = [Input(shape=shape, name="face_in_{}".format(side))
inputs = [Input(shape=shape, name=f"face_in_{side}")
for side, shape in zip(("a", "b"), input_shapes)]
logger.debug("inputs: %s", inputs)
return inputs

@ -450,7 +450,7 @@ class ModelBase():
seen = {name: 0 for name in set(self._model.output_names)}
new_names = []
for name in self._model.output_names:
new_names.append("{}_{}".format(name, seen[name]))
new_names.append(f"{name}_{seen[name]}")
seen[name] += 1
logger.debug("Output names rewritten: (old: %s, new: %s)",
self._model.output_names, new_names)

@ -510,7 +510,7 @@ class _IO():
@property
def _filename(self):
"""str: The filename for this model."""
return os.path.join(self._model_dir, "{}.h5".format(self._plugin.name))
return os.path.join(self._model_dir, f"{self._plugin.name}.h5")

@property
def model_exists(self):
@ -567,7 +567,7 @@ class _IO():
"You can try to load the model again but if the problem persists you "
"should use the Restore Tool to restore your model from backup.\n"
f"Original error: {str(err)}")
raise FaceswapError(msg)
raise FaceswapError(msg) from err
raise err
except KeyError as err:
if "unable to open object" in str(err).lower():

@ -576,7 +576,7 @@ class _IO():
"You can try to load the model again but if the problem persists you "
"should use the Restore Tool to restore your model from backup.\n"
f"Original error: {str(err)}")
raise FaceswapError(msg)
raise FaceswapError(msg) from err
raise err

logger.info("Loaded model from disk: '%s'", self._filename)

@ -605,9 +605,9 @@ class _IO():
msg = "[Saved models]"
if save_averages:
lossmsg = ["face_{}: {:.5f}".format(side, avg)
lossmsg = [f"face_{side}: {avg:.5f}"
for side, avg in zip(("a", "b"), save_averages)]
msg += " - Average loss since last save: {}".format(", ".join(lossmsg))
msg += f" - Average loss since last save: {', '.join(lossmsg)}"
logger.info(msg)

def _get_save_averages(self):
@ -699,12 +699,11 @@ class _Settings():
logger.debug("Initializing %s: (arguments: %s, mixed_precision: %s, allow_growth: %s, "
"is_predict: %s)", self.__class__.__name__, arguments, mixed_precision,
allow_growth, is_predict)
self._tf_version = [int(i) for i in tf.__version__.split(".")[:2]]
self._set_tf_settings(allow_growth, arguments.exclude_gpus)

use_mixed_precision = not is_predict and mixed_precision and get_backend() == "nvidia"
# Mixed precision moved out of experimental in tensorflow 2.4
if use_mixed_precision and self._tf_version[0] == 2 and self._tf_version[1] < 4:
if use_mixed_precision and get_tf_version() < 2.4:
self._mixed_precision = tf.keras.mixed_precision.experimental
elif use_mixed_precision:
self._mixed_precision = tf.keras.mixed_precision

@ -743,9 +742,8 @@ class _Settings():
"""
# tensorflow versions < 2.4 had different kwargs where scaling needs to be explicitly
# defined
vers = self._tf_version
kwargs = dict(loss_scale="dynamic") if vers[0] == 2 and vers[1] < 4 else dict()
logger.debug("tf version: %s, kwargs: %s", vers, kwargs)
kwargs = dict(loss_scale="dynamic") if get_tf_version() < 2.4 else {}
logger.debug("tf version: %s, kwargs: %s", get_tf_version(), kwargs)
return self._mixed_precision.LossScaleOptimizer(optimizer, **kwargs)

@classmethod

@ -821,10 +819,11 @@ class _Settings():
return False
logger.info("Enabling Mixed Precision Training.")

if exclude_gpus and self._tf_version[0] == 2 and self._tf_version[1] == 2:
if exclude_gpus and get_tf_version() == 2.2:
# TODO remove this hacky fix to disable mixed precision compatibility testing when
# tensorflow 2.2 support dropped
# pylint:disable=import-outside-toplevel,protected-access,import-error
# pylint:disable=import-outside-toplevel,protected-access
# pylint:disable=import-error,no-name-in-module
from tensorflow.python.keras.mixed_precision.experimental import \
device_compatibility_check
logger.debug("Overriding tensorflow _logged_compatibility_check parameter. Initial "

@ -833,7 +832,7 @@ class _Settings():
logger.debug("New value: %s", device_compatibility_check._logged_compatibility_check)

policy = self._mixed_precision.Policy('mixed_float16')
if self._tf_version[0] == 2 and self._tf_version[1] < 4:
if get_tf_version() < 2.4:
self._mixed_precision.set_policy(policy)
else:
self._mixed_precision.set_global_policy(policy)
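A minimal sketch of the version-gated mixed precision setup used above, outside the _Settings class (assumes an NVIDIA backend where mixed precision is actually wanted):

import tensorflow as tf
from lib.utils import get_tf_version

if get_tf_version() < 2.4:
    # Pre-2.4 the API lived under the experimental namespace
    mixed_precision = tf.keras.mixed_precision.experimental
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))
else:
    mixed_precision = tf.keras.mixed_precision
    mixed_precision.set_global_policy(mixed_precision.Policy("mixed_float16"))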
@ -1102,9 +1101,11 @@ class _Optimizer():  # pylint:disable=too-few-public-methods
optimizer, learning_rate, clipnorm, epsilon, arguments)
valid_optimizers = {"adabelief": (optimizers.AdaBelief,
dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"adam": (Adam, dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"nadam": (Nadam, dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"rms-prop": (RMSprop, dict(epsilon=epsilon))}
"adam": (optimizers.Adam,
dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"nadam": (optimizers.Nadam,
dict(beta_1=0.5, beta_2=0.99, epsilon=epsilon)),
"rms-prop": (optimizers.RMSprop, dict(epsilon=epsilon))}
self._optimizer, self._kwargs = valid_optimizers[optimizer]

self._configure(learning_rate, clipnorm, arguments)

@ -1173,7 +1174,7 @@ class _Loss():
self._uses_l2_reg = ["ssim", "gmsd"]
self._inputs = None
self._names = []
self._funcs = dict()
self._funcs = {}
logger.debug("Initialized: %s", self.__class__.__name__)

@property

@ -1248,7 +1249,7 @@ class _Loss():
side, output_names, output_shapes, output_types)
self._names.extend(["{}_{}{}".format(name, side,
"" if output_types.count(name) == 1
else "_{}".format(idx))
else f"_{idx}")
for idx, name in enumerate(output_types)])
logger.debug(self._names)
@ -1354,13 +1355,13 @@ class State():
"config_changeable_items: '%s', no_logs: %s", self.__class__.__name__,
model_dir, model_name, config_changeable_items, no_logs)
self._serializer = get_serializer("json")
filename = "{}_state.{}".format(model_name, self._serializer.file_extension)
filename = f"{model_name}_state.{self._serializer.file_extension}"
self._filename = os.path.join(model_dir, filename)
self._name = model_name
self._iterations = 0
self._sessions = dict()
self._lowest_avg_loss = dict()
self._config = dict()
self._sessions = {}
self._lowest_avg_loss = {}
self._config = {}
self._load(config_changeable_items)
self._session_id = self._new_session_id()
self._create_new_session(no_logs, config_changeable_items)

@ -1473,10 +1474,10 @@ class State():
return
state = self._serializer.load(self._filename)
self._name = state.get("name", self._name)
self._sessions = state.get("sessions", dict())
self._lowest_avg_loss = state.get("lowest_avg_loss", dict())
self._sessions = state.get("sessions", {})
self._lowest_avg_loss = state.get("lowest_avg_loss", {})
self._iterations = state.get("iterations", 0)
self._config = state.get("config", dict())
self._config = state.get("config", {})
logger.debug("Loaded state: %s", state)
self._replace_config(config_changeable_items)

@ -1670,7 +1671,7 @@ class _Inference():  # pylint:disable=too-few-public-methods
logger.debug("Compiling inference model. saved_model: %s", saved_model)
struct = self._get_filtered_structure()
model_inputs = self._get_inputs(saved_model.inputs)
compiled_layers = dict()
compiled_layers = {}
for layer in saved_model.layers:
if layer.name not in struct:
logger.debug("Skipping unused layer: '%s'", layer.name)

@ -1704,7 +1705,7 @@ class _Inference():  # pylint:disable=too-few-public-methods
logger.debug("Compiling layer '%s': layer inputs: %s", layer.name, layer_inputs)
model = layer(layer_inputs)
compiled_layers[layer.name] = model
retval = KerasModel(model_inputs, model, name="{}_inference".format(saved_model.name))
retval = KerasModel(model_inputs, model, name=f"{saved_model.name}_inference")
logger.debug("Compiled inference model '%s': %s", retval.name, retval)
return retval
@ -16,10 +16,11 @@ import cv2
import numpy as np

import tensorflow as tf
from tensorflow.python.framework import errors_impl as tf_errors
from tensorflow.python.framework import (  # pylint:disable=no-name-in-module
errors_impl as tf_errors)

from lib.training import TrainingDataGenerator
from lib.utils import FaceswapError, get_backend, get_folder, get_image_paths
from lib.utils import FaceswapError, get_backend, get_folder, get_image_paths, get_tf_version
from plugins.train._config import Config

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@ -129,8 +130,8 @@ class TrainerBase():
logger.debug("Setting up TensorBoard Logging")
log_dir = os.path.join(str(self._model.model_dir),
"{}_logs".format(self._model.name),
"session_{}".format(self._model.state.session_id))
f"{self._model.name}_logs",
f"session_{self._model.state.session_id}")
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
histogram_freq=0,  # Must be 0 or hangs
write_graph=get_backend() != "amd",
@ -251,7 +252,16 @@ class TrainerBase():
logger.trace("Updating TensorBoard log")
logs = {log[0]: log[1]
for log in zip(self._model.state.loss_names, loss)}

self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
if get_tf_version() == 2.8:
# Bug in TF 2.8 where batch recording got deleted.
# ref: https://github.com/keras-team/keras/issues/16173
for name, value in logs.items():
tf.summary.scalar(
"batch_" + name,
value,
step=self._model._model._train_counter)  # pylint:disable=protected-access

def _collate_and_store_loss(self, loss):
""" Collate the loss into totals for each side.
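The TF 2.8 branch above writes the per-batch scalars itself because the TensorBoard callback stopped doing so (keras issue 16173). A standalone sketch of the same tf.summary pattern with an explicit writer (the log directory and loss values are illustrative):

import tensorflow as tf

writer = tf.summary.create_file_writer("logs/session_1/train")
logs = {"face_a": 0.031, "face_b": 0.028}   # dummy loss values

with writer.as_default():
    for step in range(1, 4):
        for name, value in logs.items():
            tf.summary.scalar("batch_" + name, value, step=step)
writer.flush()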
@ -297,11 +307,11 @@ class TrainerBase():
The loss for each side. List should contain 2 ``floats`` side "a" in position 0 and
side "b" in position `.
"""
output = ", ".join(["Loss {}: {:.5f}".format(side, side_loss)
output = ", ".join([f"Loss {side}: {side_loss:.5f}"
for side, side_loss in zip(("A", "B"), loss)])
timestamp = time.strftime("%H:%M:%S")
output = "[{}] [#{:05d}] {}".format(timestamp, self._model.iterations, output)
print("\r{}".format(output), end="")
output = f"[{timestamp}] [#{self._model.iterations:05d}] {output}"
print(f"\r{output}", end="")

def clear_tensorboard(self):
""" Stop Tensorboard logging.

@ -335,14 +345,14 @@ class _Feeder():
self._model = model
self._images = images
self._config = config
self._target = dict()
self._samples = dict()
self._masks = dict()
self._target = {}
self._samples = {}
self._masks = {}

self._feeds = {side: self._load_generator(idx).minibatch_ab(images[side], batch_size, side)
for idx, side in enumerate(("a", "b"))}

self._display_feeds = dict(preview=self._set_preview_feed(), timelapse=dict())
self._display_feeds = dict(preview=self._set_preview_feed(), timelapse={})
logger.debug("Initialized %s:", self.__class__.__name__)

def _load_generator(self, output_index):

@ -385,7 +395,7 @@ class _Feeder():
The side ("a" or "b") as key, :class:`~lib.training_data.TrainingDataGenerator` as
value.
"""
retval = dict()
retval = {}
for idx, side in enumerate(("a", "b")):
logger.debug("Setting preview feed: (side: '%s')", side)
preview_images = self._config.get("preview_images", 14)
@ -484,9 +494,9 @@ class _Feeder():
should not be generated, in which case currently stored previews should be deleted.
"""
if not do_preview:
self._samples = dict()
self._target = dict()
self._masks = dict()
self._samples = {}
self._target = {}
self._masks = {}
return
logger.debug("Generating preview")
for side in ("a", "b"):

@ -523,7 +533,7 @@ class _Feeder():
"""
num_images = self._config.get("preview_images", 14)
num_images = min(batch_size, num_images) if batch_size is not None else num_images
retval = dict()
retval = {}
for side in ("a", "b"):
logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images)
side_images = images[side] if images is not None else self._target[side]

@ -544,9 +554,9 @@ class _Feeder():
:class:`numpy.ndarrays` for creating a time-lapse frame
"""
batchsizes = []
samples = dict()
images = dict()
masks = dict()
samples = {}
images = {}
masks = {}
for side in ("a", "b"):
batch = next(self._display_feeds["timelapse"][side])
batchsizes.append(len(batch["samples"]))

@ -607,7 +617,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
self.__class__.__name__, model, coverage_ratio)
self._model = model
self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"]
self.images = dict()
self.images = {}
self._coverage_ratio = coverage_ratio
self._scaling = scaling
logger.debug("Initialized %s", self.__class__.__name__)

@ -630,9 +640,9 @@ class _Samples():  # pylint:disable=too-few-public-methods
A compiled preview image ready for display or saving
"""
logger.debug("Showing sample")
feeds = dict()
figures = dict()
headers = dict()
feeds = {}
figures = {}
headers = {}
for idx, side in enumerate(("a", "b")):
samples = self.images[side]
faces = samples[1]

@ -647,8 +657,8 @@ class _Samples():  # pylint:disable=too-few-public-methods
for side, samples in self.images.items():
other_side = "a" if side == "b" else "b"
predictions = [preds["{0}_{0}".format(side)],
preds["{}_{}".format(other_side, side)]]
predictions = [preds[f"{side}_{side}"],
preds[f"{other_side}_{side}"]]
display = self._to_full_frame(side, samples, predictions)
headers[side] = self._get_headers(side, display[0].shape[1])
figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)

@ -716,7 +726,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
List of :class:`numpy.ndarray` of predictions received from the model
"""
logger.debug("Getting Predictions")
preds = dict()
preds = {}
standard = self._model.model.predict([feed_a, feed_b])
swapped = self._model.model.predict([feed_b, feed_a])

@ -904,9 +914,9 @@ class _Samples():  # pylint:disable=too-few-public-methods
total_width = width * 3
logger.debug("height: %s, total_width: %s", height, total_width)
font = cv2.FONT_HERSHEY_SIMPLEX
texts = ["{} ({})".format(titles[0], side),
"{0} > {0}".format(titles[0]),
"{} > {}".format(titles[0], titles[1])]
texts = [f"{titles[0]} ({side})",
f"{titles[0]} > {titles[0]}",
f"{titles[0]} > {titles[1]}"]
scaling = (width / 144) * 0.45
text_sizes = [cv2.getTextSize(texts[idx], font, scaling, 1)[0]
for idx in range(len(texts))]

@ -1002,7 +1012,7 @@ class _Timelapse():  # pylint:disable=too-few-public-methods
logger.debug("Time-lapse output set to '%s'", self._output_file)

# Rewrite paths to pull from the training images so mask and face data can be accessed
images = dict()
images = {}
for side, input_ in zip(("a", "b"), (input_a, input_b)):
training_path = os.path.dirname(self._image_paths[side][0])
images[side] = [os.path.join(training_path, os.path.basename(pth))
@ -1,2 +1,2 @@
-r _requirements_base.txt
tensorflow>=2.2.0,<2.7.0
tensorflow>=2.2.0,<2.9.0
@ -1,2 +1,2 @@
-r _requirements_base.txt
tensorflow-gpu>=2.2.0,<2.7.0
tensorflow-gpu>=2.2.0,<2.9.0
setup.py
@ -19,7 +19,7 @@ INSTALL_FAILED = False
# Tensorflow builds available from pypi
TENSORFLOW_REQUIREMENTS = {">=2.2.0,<2.4.0": ["10.1", "7.6"],
">=2.4.0,<2.5.0": ["11.0", "8.0"],
">=2.5.0,<2.7.0": ["11.2", "8.1"]}
">=2.5.0,<2.9.0": ["11.2", "8.1"]}
# Mapping of Python packages to their conda names if different from pip or in non-default channel
CONDA_MAPPING = {
# "opencv-python": ("opencv", "conda-forge"),  # Periodic issues with conda-forge opencv
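A hedged sketch of how a mapping like this can be matched against a detected CUDA/cuDNN pair to build the pip specifier; the helper below is illustrative only and not a copy of setup.py's selection logic:

TENSORFLOW_REQUIREMENTS = {">=2.2.0,<2.4.0": ["10.1", "7.6"],
                           ">=2.4.0,<2.5.0": ["11.0", "8.0"],
                           ">=2.5.0,<2.9.0": ["11.2", "8.1"]}

def pick_tf_spec(cuda, cudnn):
    """ Return a tensorflow-gpu requirement string for the detected CUDA/cuDNN,
    or None when no prebuilt wheel matches. """
    for specifier, (req_cuda, req_cudnn) in TENSORFLOW_REQUIREMENTS.items():
        if cuda.startswith(req_cuda) and cudnn.startswith(req_cudnn):
            return f"tensorflow-gpu{specifier}"
    return None

print(pick_tf_spec("11.2", "8.1"))   # tensorflow-gpu>=2.5.0,<2.9.0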
@ -43,9 +43,9 @@ class Environment():
self.enable_amd = False
self.enable_docker = False
self.enable_cuda = False
self.required_packages = list()
self.missing_packages = list()
self.conda_missing_packages = list()
self.required_packages = []
self.missing_packages = []
self.conda_missing_packages = []

self.process_arguments()
self.check_permission()

@ -54,6 +54,7 @@ class Environment():
self.output_runtime_info()
self.check_pip()
self.upgrade_pip()
self.set_ld_library_path()

self.installed_packages = self.get_installed_packages()
self.installed_packages.update(self.get_installed_conda_packages())

@ -104,7 +105,7 @@ class Environment():
args = [arg for arg in sys.argv]  # pylint:disable=unnecessary-comprehension
if self.updater:
from lib.utils import get_backend  # pylint:disable=import-outside-toplevel
args.append("--{}".format(get_backend()))
args.append(f"--{get_backend()}")

for arg in args:
if arg == "--installer":

@ -124,11 +125,11 @@ class Environment():
suffix = "cpu.txt"
req_files = ["_requirements_base.txt", f"requirements_{suffix}"]
pypath = os.path.dirname(os.path.realpath(__file__))
requirements = list()
git_requirements = list()
requirements = []
git_requirements = []
for req_file in req_files:
requirements_file = os.path.join(pypath, req_file)
with open(requirements_file) as req:
with open(requirements_file, encoding="utf8") as req:
for package in req.readlines():
package = package.strip()
# parse_requirements can't handle git dependencies, so extract and then
@ -157,15 +158,14 @@ class Environment():
if not self.updater:
self.output.info("The tool provides tips for installation\n"
"and installs required python packages")
self.output.info("Setup in %s %s" % (self.os_version[0], self.os_version[1]))
self.output.info(f"Setup in {self.os_version[0]} {self.os_version[1]}")
if not self.updater and not self.os_version[0] in ["Windows", "Linux", "Darwin"]:
self.output.error("Your system %s is not supported!" % self.os_version[0])
self.output.error(f"Your system {self.os_version[0]} is not supported!")
sys.exit(1)

def check_python(self):
""" Check python and virtual environment status """
self.output.info("Installed Python: {0} {1}".format(self.py_version[0],
self.py_version[1]))
self.output.info(f"Installed Python: {self.py_version[0]} {self.py_version[1]}")
if not (self.py_version[0].split(".")[0] == "3"
and self.py_version[0].split(".")[1] in ("7", "8")
and self.py_version[1] == "64bit") and not self.updater:

@ -179,7 +179,7 @@ class Environment():
self.output.info("Running in Conda")
if self.is_virtualenv:
self.output.info("Running in a Virtual Environment")
self.output.info("Encoding: {}".format(self.encoding))
self.output.info(f"Encoding: {self.encoding}")

def check_pip(self):
""" Check installed pip version """

@ -201,17 +201,16 @@ class Environment():
if not self.is_admin and not self.is_virtualenv:
pipexe.append("--user")
pipexe.append("pip")
run(pipexe)
run(pipexe, check=True)
import pip  # pylint:disable=import-outside-toplevel
pip_version = pip.__version__
self.output.info("Installed pip: {}".format(pip_version))
self.output.info(f"Installed pip: {pip_version}")

def get_installed_packages(self):
""" Get currently installed packages """
installed_packages = dict()
chk = Popen("\"{}\" -m pip freeze".format(sys.executable),
shell=True, stdout=PIPE)
installed = chk.communicate()[0].decode(self.encoding).splitlines()
installed_packages = {}
with Popen(f"\"{sys.executable}\" -m pip freeze", shell=True, stdout=PIPE) as chk:
installed = chk.communicate()[0].decode(self.encoding).splitlines()

for pkg in installed:
if "==" not in pkg:

@ -227,7 +226,7 @@ class Environment():
chk = os.popen("conda list").read()
installed = [re.sub(" +", " ", line.strip())
for line in chk.splitlines() if not line.startswith("#")]
retval = dict()
retval = {}
for pkg in installed:
item = pkg.split(" ")
retval[item[0]] = item[1]
@ -253,7 +252,7 @@ class Environment():
# that corresponds to the installed Cuda/cuDNN versions
self.required_packages = [pkg for pkg in self.required_packages
if not pkg.startswith("tensorflow-gpu")]
tf_ver = "tensorflow-gpu{}".format(tf_ver)
tf_ver = f"tensorflow-gpu{tf_ver}"
self.required_packages.append(tf_ver)
return

@ -262,13 +261,12 @@ class Environment():
"Tensorflow currently has no official prebuild for your CUDA, cuDNN "
"combination.\nEither install a combination that Tensorflow supports or "
"build and install your own tensorflow-gpu.\r\n"
"CUDA Version: {}\r\n"
"cuDNN Version: {}\r\n"
f"CUDA Version: {self.cuda_version}\r\n"
f"cuDNN Version: {self.cudnn_version}\r\n"
"Help:\n"
"Building Tensorflow: https://www.tensorflow.org/install/install_sources\r\n"
"Tensorflow supported versions: "
"https://www.tensorflow.org/install/source#tested_build_configurations".format(
self.cuda_version, self.cudnn_version))
"https://www.tensorflow.org/install/source#tested_build_configurations")

custom_tf = input("Location of custom tensorflow-gpu wheel (leave "
"blank to manually install): ")

@ -277,9 +275,9 @@ class Environment():
custom_tf = os.path.realpath(os.path.expanduser(custom_tf))
if not os.path.isfile(custom_tf):
self.output.error("{} not found".format(custom_tf))
self.output.error(f"{custom_tf} not found")
elif os.path.splitext(custom_tf)[1] != ".whl":
self.output.error("{} is not a valid pip wheel".format(custom_tf))
self.output.error(f"{custom_tf} is not a valid pip wheel")
elif custom_tf:
self.required_packages.append(custom_tf)
@ -294,9 +292,57 @@ class Environment():
config = {"backend": backend}
pypath = os.path.dirname(os.path.realpath(__file__))
config_file = os.path.join(pypath, "config", ".faceswap")
with open(config_file, "w") as cnf:
with open(config_file, "w", encoding="utf8") as cnf:
json.dump(config, cnf)
self.output.info("Faceswap config written to: {}".format(config_file))
self.output.info(f"Faceswap config written to: {config_file}")

def set_ld_library_path(self):
""" Update the LD_LIBRARY_PATH environment variable when activating a conda environment
and revert it when deactivating.

Notes
-----
From Tensorflow 2.7, installing Cuda Toolkit from conda-forge and tensorflow from pip
causes tensorflow to not be able to locate shared libs and hence not use the GPU.
We update the environment variable for all instances using Conda as it shouldn't hurt
anything and may help avoid conflicts with globally installed Cuda
"""
if not self.is_conda or not self.enable_cuda:
return

if self.os_version[0] == "Windows":
return

conda_prefix = os.environ["CONDA_PREFIX"]
activate_folder = os.path.join(conda_prefix, "etc", "conda", "activate.d")
deactivate_folder = os.path.join(conda_prefix, "etc", "conda", "deactivate.d")

os.makedirs(activate_folder, exist_ok=True)
os.makedirs(deactivate_folder, exist_ok=True)

activate_script = os.path.join(conda_prefix, activate_folder, f"env_vars.sh")
deactivate_script = os.path.join(conda_prefix, deactivate_folder, f"env_vars.sh")

if os.path.isfile(activate_script):
# Only create file if it does not already exist. There may be instances where people
# have created their own scripts, but these should be few and far between and those
# people should already know what they are doing.
return

conda_libs = os.path.join(conda_prefix, "lib")
shebang = "#!/bin/sh\n\n"

with open(activate_script, "w", encoding="utf8") as afile:
afile.write(f"{shebang}")
afile.write("export OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH}\n")
afile.write(f"export LD_LIBRARY_PATH='{conda_libs}':${{LD_LIBRARY_PATH}}\n")

with open(deactivate_script, "w", encoding="utf8") as afile:
afile.write(f"{shebang}")
afile.write("export LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}\n")
afile.write("unset OLD_LD_LIBRARY_PATH\n")

self.output.info(f"Cuda search path set to '{conda_libs}'")


class Output():
@ -324,14 +370,14 @@ class Output():
""" Format INFO Text """
trm = "INFO "
if self.term_support_color:
trm = "{}INFO {} ".format(self.green, self.default_color)
trm = f"{self.green}INFO {self.default_color} "
print(trm + self.__indent_text_block(text))

def warning(self, text):
""" Format WARNING Text """
trm = "WARNING "
if self.term_support_color:
trm = "{}WARNING{} ".format(self.yellow, self.default_color)
trm = f"{self.yellow}WARNING{self.default_color} "
print(trm + self.__indent_text_block(text))

def error(self, text):

@ -339,7 +385,7 @@ class Output():
global INSTALL_FAILED  # pylint:disable=global-statement
trm = "ERROR "
if self.term_support_color:
trm = "{}ERROR {} ".format(self.red, self.default_color)
trm = f"{self.red}ERROR {self.default_color} "
print(trm + self.__indent_text_block(text))
INSTALL_FAILED = True
@ -471,8 +517,8 @@ class CudaCheck():  # pylint:disable=too-few-public-methods
Initially just calls `nvcc -V` to get the installed version of Cuda currently in use.
If this fails, drills down to more OS specific checking methods.
"""
chk = Popen("nvcc -V", shell=True, stdout=PIPE, stderr=PIPE)
stdout, stderr = chk.communicate()
with Popen("nvcc -V", shell=True, stdout=PIPE, stderr=PIPE) as chk:
stdout, stderr = chk.communicate()
if not stderr:
version = re.search(r".*release (?P<cuda>\d+\.\d+)",
stdout.decode(locale.getpreferredencoding()))

@ -522,7 +568,7 @@ class CudaCheck():  # pylint:disable=too-few-public-methods
if not cudnn_checkfile:
return
found = 0
with open(cudnn_checkfile, "r") as ofile:
with open(cudnn_checkfile, "r", encoding="utf8") as ofile:
for line in ofile:
if line.lower().startswith("#define cudnn_major"):
major = line[line.rfind(" ") + 1:].strip()

@ -551,7 +597,7 @@ class CudaCheck():  # pylint:disable=too-few-public-methods
chk = os.popen("ldconfig -p | grep -P \"libcudnn.so.\\d+\" | head -n 1").read()
chk = chk.strip().replace("libcudnn.so.", "")
if not chk:
return list()
return []

cudnn_vers = chk[0]
header_files = [f"cudnn_v{cudnn_vers}.h"] + self._cudnn_header_files

@ -572,7 +618,7 @@ class CudaCheck():  # pylint:disable=too-few-public-methods
"""
# TODO A more reliable way of getting the windows location
if not self.cuda_path:
return list()
return []
scandir = os.path.join(self.cuda_path, "include")
cudnn_checkfiles = [os.path.join(scandir, header) for header in self._cudnn_header_files]
return cudnn_checkfiles
@ -701,7 +747,7 @@ class Install():
channel = None if len(pkg) != 2 else pkg[1]
pkg = pkg[0]
if version:
pkg = "{}{}".format(pkg, ",".join("".join(spec) for spec in version))
pkg = f"{pkg}{','.join(''.join(spec) for spec in version)}"
if self.env.is_conda and not pkg.startswith("git"):
if pkg.startswith("tensorflow-gpu"):
# From TF 2.4 onwards, Anaconda Tensorflow becomes a mess. The version of 2.5

@ -760,13 +806,14 @@ class Install():
package = f"\"{package}\""
condaexe.append(package)

self.output.info("Installing {}".format(package.replace("\"", "")))
clean_pkg = package.replace("\"", "")
self.output.info(f"Installing {clean_pkg}")
shell = self.env.os_version[0] == "Windows"
try:
if verbose:
run(condaexe, check=True, shell=shell)
else:
with open(os.devnull, "w") as devnull:
with open(os.devnull, "w", encoding="utf8") as devnull:
run(condaexe, stdout=devnull, stderr=devnull, check=True, shell=shell)
except CalledProcessError:
if not conda_only:

@ -809,14 +856,16 @@ class Install():
pkgs = ["cudatoolkit", "cudnn"]
shell = self.env.os_version[0] == "Windows"
for pkg in pkgs:
chk = Popen(condaexe + [pkg], shell=shell, stdout=PIPE)
available = [line.split()
for line in chk.communicate()[0].decode(self.env.encoding).splitlines()
if line.startswith(pkg)]
compatible = [req for req in available
if (pkg == "cudatoolkit" and req[1].startswith(versions[0]))
or (pkg == "cudnn" and versions[0] in req[2]
and req[1].startswith(versions[1]))]
with Popen(condaexe + [pkg], shell=shell, stdout=PIPE) as chk:
available = [line.split()
for line
in chk.communicate()[0].decode(self.env.encoding).splitlines()
if line.startswith(pkg)]
compatible = [req for req in available
if (pkg == "cudatoolkit" and req[1].startswith(versions[0]))
or (pkg == "cudnn" and versions[0] in req[2]
and req[1].startswith(versions[1]))]

candidate = "==".join(sorted(compatible, key=lambda x: x[1])[-1][:2])
self.conda_installer(candidate, verbose=True, conda_only=True)
@ -828,6 +877,8 @@ class Tips():
def docker_no_cuda(self):
""" Output Tips for Docker without Cuda """

path = os.path.dirname(os.path.realpath(__file__))
self.output.info(
"1. Install Docker\n"
"https://www.docker.com/community-edition\n\n"

@ -837,7 +888,7 @@ class Tips():
"# without GUI\n"
"docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\tdeepfakes-cpu\n\n"
"# with gui. tools.py gui working.\n"
"## enable local access to X11 server\n"

@ -845,7 +896,7 @@ class Tips():
"## create container\n"
"nvidia-docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
"\t-e DISPLAY=unix$DISPLAY \\ \n"
"\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n"

@ -854,12 +905,13 @@ class Tips():
"\t-e UID=`id -u` \\ \n"
"\tdeepfakes-cpu \n\n"
"4. Open a new terminal to run faceswap.py in /srv\n"
"docker exec -it deepfakes-cpu bash".format(
path=os.path.dirname(os.path.realpath(__file__))))
"docker exec -it deepfakes-cpu bash")
self.output.info("That's all you need to do with a docker. Have fun.")

def docker_cuda(self):
""" Output Tips for Docker wit Cuda"""

path = os.path.dirname(os.path.realpath(__file__))
self.output.info(
"1. Install Docker\n"
"https://www.docker.com/community-edition\n\n"

@ -873,7 +925,7 @@ class Tips():
"# without gui \n"
"docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\tdeepfakes-gpu\n\n"
"# with gui.\n"
"## enable local access to X11 server\n"

@ -883,7 +935,7 @@ class Tips():
"## create container\n"
"nvidia-docker run -tid -p 8888:8888 \\ \n"
"\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n"
"\t-v {path}:/srv \\ \n"
f"\t-v {path}:/srv \\ \n"
"\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
"\t-e DISPLAY=unix$DISPLAY \\ \n"
"\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n"

@ -892,8 +944,7 @@ class Tips():
"\t-e UID=`id -u` \\ \n"
"\tdeepfakes-gpu\n\n"
"6. Open a new terminal to interact with the project\n"
"docker exec deepfakes-gpu python /srv/faceswap.py gui\n".format(
path=os.path.dirname(os.path.realpath(__file__))))
"docker exec deepfakes-gpu python /srv/faceswap.py gui\n")

def macos(self):
""" Output Tips for macOS"""
@ -70,41 +70,3 @@ def test_loss_wrapper(loss_func):
else:
output = output.numpy()
assert output.dtype == "float32" and not np.isnan(output)


@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
def test_dssim_channels_last(dummy):  # pylint:disable=unused-argument
""" Basic test for DSSIM Loss """
prev_data = K.image_data_format()
K.set_image_data_format('channels_last')
for input_dim, kernel_size in zip([32, 33], [2, 3]):
input_shape = [input_dim, input_dim, 3]
var_x = np.random.random_sample(4 * input_dim * input_dim * 3)
var_x = var_x.reshape([4] + input_shape)
var_y = np.random.random_sample(4 * input_dim * input_dim * 3)
var_y = var_y.reshape([4] + input_shape)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape,
activation='relu'))
model.add(Conv2D(3, (3, 3), padding='same', input_shape=input_shape,
activation='relu'))
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
model.compile(loss=losses.DSSIMObjective(kernel_size=kernel_size),
metrics=['mse'],
optimizer=adam)
model.fit(var_x, var_y, batch_size=2, epochs=1, shuffle='batch')

# Test same
x_1 = K.constant(var_x, 'float32')
x_2 = K.constant(var_x, 'float32')
dssim = losses.DSSIMObjective(kernel_size=kernel_size)
assert_allclose(0.0, K.eval(dssim(x_1, x_2)), atol=1e-4)

# Test opposite
x_1 = K.zeros([4] + input_shape)
x_2 = K.ones([4] + input_shape)
dssim = losses.DSSIMObjective(kernel_size=kernel_size)
assert_allclose(0.5, K.eval(dssim(x_1, x_2)), atol=1e-4)

K.set_image_data_format(prev_data)
@ -76,11 +76,11 @@ def _test_optimizer(optimizer, target=0.75):
@pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()])
def test_adam(dummy):  # pylint:disable=unused-argument
""" Test for custom Adam optimizer """
_test_optimizer(k_optimizers.Adam(), target=0.5)
_test_optimizer(k_optimizers.Adam(decay=1e-3), target=0.5)
_test_optimizer(k_optimizers.Adam(), target=0.45)
_test_optimizer(k_optimizers.Adam(decay=1e-3), target=0.45)


@pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()])
def test_adabelief(dummy):  # pylint:disable=unused-argument
""" Test for custom Adam optimizer """
_test_optimizer(optimizers.AdaBelief(), target=0.5)
_test_optimizer(optimizers.AdaBelief(), target=0.45)
@ -25,5 +25,6 @@ def test_backend(dummy):  # pylint:disable=unused-argument
def test_keras(dummy):  # pylint:disable=unused-argument
""" Sanity check to ensure that tensorflow keras is being used for CPU and standard
keras for AMD. """
assert ((_BACKEND == "cpu" and keras.__version__ in ("2.3.0-tf", "2.4.0")) or
assert ((_BACKEND == "cpu" and keras.__version__ in ("2.3.0-tf", "2.4.0",
"2.6.0", "2.7.0", "2.8.0")) or
(_BACKEND == "amd" and keras.__version__ == "2.2.4"))