Add Laplacian Pyramid Loss

torzdf 2022-06-18 02:29:19 +01:00
parent 04337e0c5e
commit d9c84a5f9f
15 changed files with 409 additions and 108 deletions

View File

@ -11,12 +11,12 @@ from ._base import set_exclude_devices # noqa
backend = get_backend()
if backend == "nvidia" and platform.system().lower() == "darwin":
from .nvidia_apple import NvidiaAppleStats as GPUStats # noqa
from .nvidia_apple import NvidiaAppleStats as GPUStats # type:ignore # noqa
elif backend == "nvidia":
from .nvidia import NvidiaStats as GPUStats # noqa
from .nvidia import NvidiaStats as GPUStats # type:ignore # noqa
elif backend == "amd":
from .amd import AMDStats as GPUStats, setup_plaidml # noqa
from .amd import AMDStats as GPUStats, setup_plaidml # type:ignore # noqa
elif backend == "apple_silicon":
from .apple_silicon import AppleSiliconStats as GPUStats # noqa
from .apple_silicon import AppleSiliconStats as GPUStats # type:ignore # noqa
elif backend == "cpu":
from .cpu import CPUStats as GPUStats # noqa
from .cpu import CPUStats as GPUStats # type:ignore # noqa

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
""" Parent class for obtaining Stats for various GPU/TPU backends. All GPU Stats should inherit
from the :class:`GPUStats` class contained here. """
from the :class:`_GPUStats` class contained here. """
import logging
import os
@ -23,7 +23,7 @@ class GPUInfo(TypedDict):
vram: List[int]
driver: str
devices: List[str]
devices_active: List[str]
devices_active: List[int]
class BiggestGPUInfo(TypedDict):
@ -50,7 +50,7 @@ def set_exclude_devices(devices: List[int]) -> None:
_EXCLUDE_DEVICES.extend(devices)
class GPUStats():
class _GPUStats():
""" Parent class for returning information of GPUs used. """
def __init__(self, log: bool = True) -> None:
@ -67,8 +67,8 @@ class GPUStats():
self._handles: list = self._get_handles()
self._driver: str = self._get_driver()
self._device_names: List[str] = self._get_device_names()
self._vram: List[float] = self._get_vram()
self._vram_free: List[float] = self._get_free_vram()
self._vram: List[int] = self._get_vram()
self._vram_free: List[int] = self._get_free_vram()
if get_backend() != "cpu" and not self._active_devices:
self._log("warning", "No GPU detected")
@ -164,8 +164,8 @@ class GPUStats():
devices = [idx for idx in range(self._device_count) if idx not in _EXCLUDE_DEVICES]
env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if env_devices:
env_devices = [int(i) for i in env_devices.split(",")]
devices = [idx for idx in devices if idx in env_devices]
new_devices = [int(i) for i in env_devices.split(",")]
devices = [idx for idx in devices if idx in new_devices]
self._log("debug", f"Active GPU Devices: {devices}")
return devices
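
As a quick illustration of the active-device filtering above (illustrative values, not taken from the diff): excluded devices are dropped first, then the result is intersected with CUDA_VISIBLE_DEVICES if that variable is set.

# Editorial sketch of the filtering logic; the device count and environment value are made up.
import os

_EXCLUDE_DEVICES = [1]                      # e.g. excluded from the command line
device_count = 4                            # pretend four GPUs were detected
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"  # hypothetical environment setting

devices = [idx for idx in range(device_count) if idx not in _EXCLUDE_DEVICES]
env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if env_devices:
    new_devices = [int(i) for i in env_devices.split(",")]
    devices = [idx for idx in devices if idx in new_devices]
print(devices)  # [0, 2] -> device 1 excluded, device 3 hidden by CUDA_VISIBLE_DEVICES
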
@ -202,7 +202,7 @@ class GPUStats():
"""
raise NotImplementedError()
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Override to obtain the total VRAM in Megabytes for each connected GPU.
Returns
@ -213,8 +213,8 @@ class GPUStats():
"""
raise NotImplementedError()
def _get_free_vram(self) -> List[float]:
""" Override to obrain the amount of VRAM that is available, in Megabytes, for each
def _get_free_vram(self) -> List[int]:
""" Override to obtain the amount of VRAM that is available, in Megabytes, for each
connected GPU.
Returns

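The renamed _GPUStats parent is an abstract base: each backend module subclasses it and overrides the _get_* hooks (device handles, names, total and free VRAM, plus the driver and device-count lookups handled the same way), and the base class exposes the results through its properties and warns when no active devices are found. A minimal, purely hypothetical subclass sketched from the interface visible in this diff would look roughly like:

from typing import List
from ._base import _GPUStats

class DummyStats(_GPUStats):
    """ Hypothetical backend, shown only to illustrate the override pattern. """

    def _get_handles(self) -> list:
        return [0]                    # one fake device handle

    def _get_device_names(self) -> List[str]:
        return ["Dummy GPU"]

    def _get_vram(self) -> List[int]:
        return [8192]                 # total VRAM in Megabytes per device

    def _get_free_vram(self) -> List[int]:
        return [4096]                 # available VRAM in Megabytes per device
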
View File

@ -9,7 +9,7 @@ from typing import List, Optional
import plaidml
from ._base import GPUStats, _EXCLUDE_DEVICES
from ._base import _GPUStats, _EXCLUDE_DEVICES
_PLAIDML_INITIALIZED: bool = False
@ -40,7 +40,7 @@ def setup_plaidml(log_level: str, exclude_devices: List[int]) -> None:
logger.info("Successfully set up for PlaidML")
class AMDStats(GPUStats):
class AMDStats(_GPUStats):
""" Holds information and statistics about AMD GPU(s) available on the currently
running system.
@ -104,9 +104,9 @@ class AMDStats(GPUStats):
return retval
@property
def _all_vram(self) -> List[float]:
def _all_vram(self) -> List[int]:
""" list: The VRAM of each GPU device that PlaidML has discovered. """
return [int(device.get("globalMemSize", 0)) / (1024 * 1024)
return [int(device.get("globalMemSize", 0) / (1024 * 1024))
for device in self._device_details]
@property
@ -159,7 +159,7 @@ class AMDStats(GPUStats):
self._log("debug", "Setting PlaidML Default Logger")
plaidml.DEFAULT_LOG_HANDLER = logging.getLogger("plaidml_root")
plaidml.DEFAULT_LOG_HANDLER.propagate = 0
plaidml.DEFAULT_LOG_HANDLER.propagate = False
numeric_level = getattr(logging, self._log_level, None)
if numeric_level < 10: # DEBUG Logging
@ -344,8 +344,8 @@ class AMDStats(GPUStats):
str
The current AMD GPU driver versions
"""
drivers = [device.get("driverVersion", "No Driver Found")
for device in self._device_details]
drivers = "|".join([device.get("driverVersion", "No Driver Found")
for device in self._device_details])
self._log("debug", f"GPU Drivers: {drivers}")
return drivers
@ -361,7 +361,7 @@ class AMDStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each connected AMD GPU as identified in
:attr:`_handles`.
@ -374,7 +374,7 @@ class AMDStats(GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD
GPU.

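A quick sanity check on the VRAM change above (figures are illustrative): an 8 GiB card reporting globalMemSize = 8589934592 bytes now yields int(8589934592 / (1024 * 1024)) = 8192, i.e. whole Megabytes as an int, matching the new List[int] signature, where the old expression produced a float.
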
View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """
from typing import List, Optional
from typing import Any, List
import os
import psutil
@ -8,13 +8,13 @@ import tensorflow as tf
from lib.utils import FaceswapError
from ._base import GPUStats
from ._base import _GPUStats
_METAL_INITIALIZED: bool = False
class AppleSiliconStats(GPUStats):
class AppleSiliconStats(_GPUStats):
""" Holds information and statistics about Apple Silicon SoC(s) available on the currently
running Apple system.
@ -35,7 +35,7 @@ class AppleSiliconStats(GPUStats):
"""
def __init__(self, log: bool = True) -> None:
# Following attribute set in :func:``_initialize``
self._tf_devices: Optional(List[str]) = None
self._tf_devices: List[Any] = []
super().__init__(log=log)
@ -155,7 +155,7 @@ class AppleSiliconStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in
:attr:`_handles`.
@ -170,12 +170,12 @@ class AppleSiliconStats(GPUStats):
list
The RAM in Megabytes for each available Apple Silicon SoC
"""
vram = [(psutil.virtual_memory().total / self._device_count) / (1024 * 1024)
vram = [int((psutil.virtual_memory().total / self._device_count) / (1024 * 1024))
for _ in range(self._device_count)]
self._log("debug", f"SoC RAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each available Apple
Silicon SoC.
@ -185,7 +185,7 @@ class AppleSiliconStats(GPUStats):
List of `float`s containing the amount of RAM available, in Megabytes, for each
available SoC as corresponding to the values in :attr:`_handles
"""
vram = [(psutil.virtual_memory().available / self._device_count) / (1024 * 1024)
vram = [int((psutil.virtual_memory().available / self._device_count) / (1024 * 1024))
for _ in range(self._device_count)]
self._log("debug", f"SoC RAM free: {vram}")
return vram

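For example (illustrative numbers): on a single-SoC machine with 16 GiB of unified memory, psutil.virtual_memory().total is 17179869184, so the figure becomes int((17179869184 / 1) / (1024 * 1024)) = 16384 MB. System RAM stands in for VRAM here because Apple Silicon shares one memory pool between CPU and GPU.
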
View File

@ -4,19 +4,19 @@
from typing import List
from ._base import GPUStats
from ._base import _GPUStats
class CPUStats(GPUStats):
class CPUStats(_GPUStats):
""" Holds information and statistics about the CPU on the currently running system.
Notes
-----
The information held here is not useful, but GPUStats is dynamically imported depending on the
backend used, so we need to make sure this class is available for Faceswap run on the CPU
The information held here is not useful, but _GPUStats is dynamically imported depending on
the backend used, so we need to make sure this class is available for Faceswap run on the CPU
Backend.
The base :class:`GPUStats` handles the dummying in of information when no GPU is detected.
The base :class:`_GPUStats` handles the dummying in of information when no GPU is detected.
Parameters
----------
@ -49,7 +49,7 @@ class CPUStats(GPUStats):
list
An empty list for CPU Backends
"""
handles = []
handles: list = []
self._log("debug", f"GPU Handles found: {len(handles)}")
return handles
@ -73,11 +73,11 @@ class CPUStats(GPUStats):
list
An empty list for CPU backends
"""
names = []
names: List[str] = []
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the RAM in Megabytes for the running system.
Returns
@ -85,11 +85,11 @@ class CPUStats(GPUStats):
list
An empty list for CPU backends
"""
vram = []
vram: List[int] = []
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of RAM that is available, in Megabytes, for the running system.
Returns
@ -97,6 +97,6 @@ class CPUStats(GPUStats):
list
An empty list for CPU backends
"""
vram = []
vram: List[int] = []
self._log("debug", f"GPU VRAM free: {vram}")
return vram

View File

@ -7,10 +7,10 @@ import pynvml
from lib.utils import FaceswapError
from ._base import GPUStats
from ._base import _GPUStats
class NvidiaStats(GPUStats):
class NvidiaStats(_GPUStats):
""" Holds information and statistics about Nvidia GPU(s) available on the currently
running system.
@ -125,7 +125,7 @@ class NvidiaStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.
@ -139,7 +139,7 @@ class NvidiaStats(GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.

View File

@ -6,10 +6,10 @@ import pynvx
from lib.utils import FaceswapError
from ._base import GPUStats
from ._base import _GPUStats
class NvidiaAppleStats(GPUStats):
class NvidiaAppleStats(_GPUStats):
""" Holds information and statistics about Nvidia GPU(s) available on the currently
running Apple system.
@ -105,7 +105,7 @@ class NvidiaAppleStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.
@ -120,7 +120,7 @@ class NvidiaAppleStats(GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.

View File

@ -4,7 +4,7 @@
from __future__ import absolute_import
import logging
from typing import List, Tuple
from typing import Callable, List, Tuple
import numpy as np
import plaidml
@ -621,18 +621,155 @@ class GMSDLoss(): # pylint:disable=too-few-public-methods
return output
class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods
""" Laplacian Pyramid Loss Function
Notes
-----
Channels last implementation on square images only.
Parameters
----------
max_levels: int, Optional
The max number of laplacian pyramid levels to use. Default: `5`
gaussian_size: int, Optional
The size of the gaussian kernel. Default: `5`
gaussian_sigma: float, optional
The gaussian sigma. Default: 1.0
References
----------
https://arxiv.org/abs/1707.05776
https://github.com/nathanaelbosch/generative-latent-optimization/blob/master/utils.py
"""
def __init__(self,
max_levels: int = 5,
gaussian_size: int = 5,
gaussian_sigma: float = 1.0) -> None:
self._max_levels = max_levels
self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)
self._shape: Tuple[int, ...] = ()
@classmethod
def _get_gaussian_kernel(cls, size: int, sigma: float) -> plaidml.tile.Value:
""" Obtain the base gaussian kernel for the Laplacian Pyramid.
Parameters
----------
size: int, Optional
The size of the gaussian kernel
sigma: float
The gaussian sigma
Returns
-------
:class:`plaidml.tile.Value`
The base single channel Gaussian kernel
"""
assert size % 2 == 1, ("kernel size must be uneven")
x_1 = np.linspace(- (size // 2), size // 2, size, dtype="float32")
x_1 /= np.sqrt(2)*sigma
x_2 = x_1 ** 2
kernel = np.exp(- x_2[:, None] - x_2[None, :])
kernel /= kernel.sum()
kernel = np.reshape(kernel, (size, size, 1, 1))
return K.constant(kernel)
def _conv_gaussian(self, inputs: plaidml.tile.Value) -> plaidml.tile.Value:
""" Perform Gaussian convolution on a batch of images.
Parameters
----------
inputs: :class:`plaidml.tile.Value`
The input batch of images to perform Gaussian convolution on.
Returns
-------
:class:`plaidml.tile.Value`
The convolved images
"""
channels = self._shape[-1]
gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))
# PlaidML doesn't implement replication padding like pytorch. This is an inefficient way to
# implement it for a square gaussian kernel
size = K.int_shape(self._gaussian_kernel)[1] // 2
padded_inputs = inputs
for _ in range(size):
padded_inputs = pad(padded_inputs, # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
([0, 0], [1, 1], [1, 1], [0, 0]),
mode="REFLECT")
retval = K.conv2d(padded_inputs, gauss, strides=(1, 1), padding="valid")
return retval
def _get_laplacian_pyramid(self, inputs: plaidml.tile.Value) -> List[plaidml.tile.Value]:
""" Obtain the Laplacian Pyramid.
Parameters
----------
inputs: :class:`plaidml.tile.Value`
The input batch of images to run through the Laplacian Pyramid
Returns
-------
list
The tensors produced from the Laplacian Pyramid
"""
pyramid = []
current = inputs
for _ in range(self._max_levels):
gauss = self._conv_gaussian(current)
diff = current - gauss
pyramid.append(diff)
current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
pyramid.append(current)
return pyramid
def __call__(self,
y_true: plaidml.tile.Value,
y_pred: plaidml.tile.Value) -> plaidml.tile.Value:
""" Calculate the Laplacian Pyramid Loss.
Parameters
----------
y_true: :class:`plaidml.tile.Value`
The ground truth value
y_pred: :class:`plaidml.tile.Value`
The predicted value
Returns
-------
:class: `plaidml.tile.Value`
The loss value
"""
if not self._shape:
self._shape = K.int_shape(y_pred)
pyramid_true = self._get_laplacian_pyramid(y_true)
pyramid_pred = self._get_laplacian_pyramid(y_pred)
losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
loss = K.sum(losses * self._weights)
return loss
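
To make the weighting concrete (an aside, not part of the diff): with the default max_levels=5 the constant weights are 2 ** (-2 * idx) for idx 0..5, i.e. [1.0, 0.25, 0.0625, 0.015625, 0.00390625, 0.0009765625]. Each of the five detail levels plus the final low-pass residual contributes its mean absolute difference (the K.sum(K.abs(...)) term divided by the element count), and the result is the weighted sum, so the full-resolution detail level dominates while the small residual image at the end of the pyramid contributes least.
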
class LossWrapper(): # pylint:disable=too-few-public-methods
""" A wrapper class for multiple keras losses to enable multiple weighted loss functions on a
single output and masking.
"""
def __init__(self):
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
self._loss_functions = []
self._loss_weights = []
self._mask_channels = []
self._loss_functions: List[Callable] = []
self._loss_weights: List[float] = []
self._mask_channels: List[int] = []
logger.debug("Initialized: %s", self.__class__.__name__)
def add_loss(self, function, weight=1.0, mask_channel=-1):
def add_loss(self,
function,
weight: float = 1.0,
mask_channel: int = -1) -> None:
""" Add the given loss function with the given weight to the loss function chain.
Parameters
@ -651,21 +788,23 @@ class LossWrapper(): # pylint:disable=too-few-public-methods
self._loss_weights.append(weight)
self._mask_channels.append(mask_channel)
def __call__(self, y_true, y_pred):
def __call__(self,
y_true: plaidml.tile.Value,
y_pred: plaidml.tile.Value) -> plaidml.tile.Value:
""" Call the sub loss functions for the loss wrapper.
Weights are returned as the weighted sum of the chosen losses.
Parameters
----------
y_true: tensor or variable
y_true: :class:`plaidml.tile.Value`
The ground truth value
y_pred: tensor or variable
y_pred: :class:`plaidml.tile.Value`
The predicted value
Returns
-------
tensor
:class:`plaidml.tile.Value`
The final loss value
"""
loss = 0.0
@ -685,15 +824,19 @@ class LossWrapper(): # pylint:disable=too-few-public-methods
return loss
@classmethod
def _apply_mask(cls, y_true, y_pred, mask_channel, mask_prop=1.0):
def _apply_mask(cls,
y_true: plaidml.tile.Value,
y_pred: plaidml.tile.Value,
mask_channel: int,
mask_prop: float = 1.0) -> Tuple[plaidml.tile.Value, plaidml.tile.Value]:
""" Apply the mask to the input y_true and y_pred. If a mask is not required then
return the unmasked inputs.
Parameters
----------
y_true: tensor or variable
y_true: :class:`plaidml.tile.Value`
The ground truth value
y_pred: tensor or variable
y_pred: :class:`plaidml.tile.Value`
The predicted value
mask_channel: int
The channel within y_true that the required mask resides in

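A usage sketch for the wrapper (shapes, weights and the mask channel below follow the unit test added later in this commit; the keras imports are assumed): y_true carries the mask as an extra trailing channel, mask_channel selects it, and each added loss is applied to the (optionally masked) tensors and summed with its weight.

import numpy as np
from keras import backend as K
from keras import losses as k_losses

y_true = K.variable(np.random.random((2, 64, 64, 4)))   # 3 colour channels + mask in channel 3
y_pred = K.variable(np.random.random((2, 64, 64, 3)))

wrapper = LossWrapper()
wrapper.add_loss(LaplacianPyramidLoss(), weight=1.0, mask_channel=-1)      # unmasked
wrapper.add_loss(k_losses.mean_squared_error, weight=2.0, mask_channel=3)  # masked by channel 3
loss = wrapper(y_true, y_pred)   # weighted sum of the individual losses
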
View File

@ -4,7 +4,7 @@
from __future__ import absolute_import
import logging
from typing import List, Tuple
from typing import Callable, List, Tuple
import numpy as np
import tensorflow as tf
@ -558,6 +558,136 @@ class GMSDLoss(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
return output
class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods
""" Laplacian Pyramid Loss Function
Notes
-----
Channels last implementation on square images only.
Parameters
----------
max_levels: int, Optional
The max number of laplacian pyramid levels to use. Default: `5`
gaussian_size: int, Optional
The size of the gaussian kernel. Default: `5`
gaussian_sigma: float, optional
The gaussian sigma. Default: 1.0
References
----------
https://arxiv.org/abs/1707.05776
https://github.com/nathanaelbosch/generative-latent-optimization/blob/master/utils.py
"""
def __init__(self,
max_levels: int = 5,
gaussian_size: int = 5,
gaussian_sigma: float = 1.0) -> None:
self._max_levels = max_levels
self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)
@classmethod
def _get_gaussian_kernel(cls, size: int, sigma: float) -> tf.Tensor:
""" Obtain the base gaussian kernel for the Laplacian Pyramid.
Parameters
----------
size: int, Optional
The size of the gaussian kernel
sigma: float
The gaussian sigma
Returns
-------
tf.Tensor
The base single channel Gaussian kernel
"""
assert size % 2 == 1, ("kernel size must be uneven")
x_1 = np.linspace(- (size // 2), size // 2, size, dtype="float32")
x_1 /= np.sqrt(2)*sigma
x_2 = x_1 ** 2
kernel = np.exp(- x_2[:, None] - x_2[None, :])
kernel /= kernel.sum()
kernel = np.reshape(kernel, (size, size, 1, 1))
return K.constant(kernel)
def _conv_gaussian(self, inputs: tf.Tensor) -> tf.Tensor:
""" Perform Gaussian convolution on a batch of images.
Parameters
----------
inputs: :class:`tf.Tensor`
The input batch of images to perform Gaussian convolution on.
Returns
-------
:class:`tf.Tensor`
The convolved images
"""
channels = K.int_shape(inputs)[-1]
gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))
# TF doesn't implement replication padding like pytorch. This is an inefficient way to
# implement it for a square gaussian kernel
size = self._gaussian_kernel.shape[1] // 2
padded_inputs = inputs
for _ in range(size):
padded_inputs = tf.pad(padded_inputs, # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
([0, 0], [1, 1], [1, 1], [0, 0]),
mode="SYMMETRIC")
retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
return retval
def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> List[tf.Tensor]:
""" Obtain the Laplacian Pyramid.
Parameters
----------
inputs: :class:`tf.Tensor`
The input batch of images to run through the Laplacian Pyramid
Returns
-------
list
The tensors produced from the Laplacian Pyramid
"""
pyramid = []
current = inputs
for _ in range(self._max_levels):
gauss = self._conv_gaussian(current)
diff = current - gauss
pyramid.append(diff)
current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
pyramid.append(current)
return pyramid
def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
""" Calculate the Laplacian Pyramid Loss.
Parameters
----------
y_true: :class:`tf.Tensor`
The ground truth value
y_pred: :class:`tf.Tensor`
The predicted value
Returns
-------
:class: `tf.Tensor`
The loss value
"""
pyramid_true = self._get_laplacian_pyramid(y_true)
pyramid_pred = self._get_laplacian_pyramid(y_pred)
losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
loss = K.sum(losses * self._weights)
return loss
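
For readers who want the pyramid outside of Keras, here is a small NumPy-only restatement of the same steps (an editorial reference sketch on a single-channel image, not code from the commit): blur with the normalised Gaussian kernel using edge-style padding, which the repeated one-pixel SYMMETRIC pads above emulate, keep the difference as the detail level, then 2x2 average-pool the blurred image and repeat.

import numpy as np

def gaussian_kernel(size=5, sigma=1.0):
    # Same construction as _get_gaussian_kernel above, returned as a (size, size) array
    x = np.linspace(-(size // 2), size // 2, size, dtype="float32")
    x /= np.sqrt(2) * sigma
    kernel = np.exp(-x[:, None] ** 2 - x[None, :] ** 2)
    return kernel / kernel.sum()

def blur(image, kernel):
    # 'edge' padding plays the role of the repeated SYMMETRIC padding in _conv_gaussian
    pad = kernel.shape[0] // 2
    padded = np.pad(image, pad, mode="edge")
    out = np.zeros_like(image)
    for d_y in range(kernel.shape[0]):        # direct convolution; slow but fine for a sketch
        for d_x in range(kernel.shape[1]):
            out += kernel[d_y, d_x] * padded[d_y:d_y + image.shape[0], d_x:d_x + image.shape[1]]
    return out

def laplacian_pyramid(image, max_levels=5):
    kernel = gaussian_kernel()
    pyramid, current = [], image
    for _ in range(max_levels):
        gauss = blur(current, kernel)
        pyramid.append(current - gauss)       # high-frequency detail level
        rows, cols = gauss.shape
        current = gauss.reshape(rows // 2, 2, cols // 2, 2).mean(axis=(1, 3))  # 2x2 average pool
    pyramid.append(current)                   # final low-pass residual
    return pyramid

levels = laplacian_pyramid(np.random.rand(64, 64).astype("float32"))
print([level.shape for level in levels])  # [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4), (2, 2)]
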
class LossWrapper():
""" A wrapper class for multiple keras losses to enable multiple masked weighted loss
functions on a single output.
@ -586,7 +716,7 @@ class LossWrapper():
logger.debug("Initialized: %s", self.__class__.__name__)
def add_loss(self,
function: tf.keras.losses.Loss,
function: Callable,
weight: float = 1.0,
mask_channel: int = -1) -> None:
""" Add the given loss function with the given weight to the loss function chain.

View File

@ -122,7 +122,7 @@ class Extractor():
For align/mask (2nd/3rd pass operations) the :attr:`ExtractMedia.detected_faces` should
also be populated by calling :func:`ExtractMedia.set_detected_faces`.
"""
qname = "extract{}_{}_in".format(self._instance, self._current_phase[0])
qname = f"extract{self._instance}_{self._current_phase[0]}_in"
retval = self._queues[qname]
logger.trace("%s: %s", qname, retval)
return retval
@ -193,7 +193,7 @@ class Extractor():
The batch size to use for this plugin type
"""
logger.debug("Overriding batchsize for plugin_type: %s to: %s", plugin_type, batchsize)
plugin = getattr(self, "_{}".format(plugin_type))
plugin = getattr(self, f"_{plugin_type}")
plugin.batchsize = batchsize
def launch(self):
@ -283,10 +283,10 @@ class Extractor():
@property
def _vram_per_phase(self):
""" dict: The amount of vram required for each phase in :attr:`_flow`. """
retval = dict()
retval = {}
for phase in self._flow:
plugin_type, idx = self._get_plugin_type_and_index(phase)
attr = getattr(self, "_{}".format(plugin_type))
attr = getattr(self, f"_{plugin_type}")
attr = attr[idx] if idx is not None else attr
retval[phase] = attr.vram
logger.trace(retval)
@ -322,10 +322,9 @@ class Extractor():
def _output_queue(self):
""" Return the correct output queue depending on the current phase """
if self.final_pass:
qname = "extract{}_{}_out".format(self._instance, self._final_phase)
qname = f"extract{self._instance}_{self._final_phase}_out"
else:
qname = "extract{}_{}_in".format(self._instance,
self._phases[self._phase_index + 1][0])
qname = f"extract{self._instance}_{self._phases[self._phase_index + 1][0]}_in"
retval = self._queues[qname]
logger.trace("%s: %s", qname, retval)
return retval
@ -336,7 +335,7 @@ class Extractor():
retval = []
for phase in self._flow:
plugin_type, idx = self._get_plugin_type_and_index(phase)
attr = getattr(self, "_{}".format(plugin_type))
attr = getattr(self, f"_{plugin_type}")
attr = attr[idx] if idx is not None else attr
retval.append(attr)
logger.trace("All Plugins: %s", retval)
@ -348,7 +347,7 @@ class Extractor():
retval = []
for phase in self._current_phase:
plugin_type, idx = self._get_plugin_type_and_index(phase)
attr = getattr(self, "_{}".format(plugin_type))
attr = getattr(self, f"_{plugin_type}")
retval.append(attr[idx] if idx is not None else attr)
logger.trace("Active plugins: %s", retval)
return retval
@ -362,7 +361,7 @@ class Extractor():
retval.append("detect")
if aligner is not None and aligner.lower() != "none":
retval.append("align")
retval.extend(["mask_{}".format(idx)
retval.extend([f"mask_{idx}"
for idx, mask in enumerate(masker)
if mask is not None and mask.lower() != "none"])
logger.debug("flow: %s", retval)
@ -400,9 +399,9 @@ class Extractor():
def _add_queues(self):
""" Add the required processing queues to Queue Manager """
queues = dict()
tasks = ["extract{}_{}_in".format(self._instance, phase) for phase in self._flow]
tasks.append("extract{}_{}_out".format(self._instance, self._final_phase))
queues = {}
tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow]
tasks.append(f"extract{self._instance}_{self._final_phase}_out")
for task in tasks:
# Limit queue size to avoid stacking ram
queue_manager.add_queue(task, maxsize=self._queue_size)
@ -552,17 +551,17 @@ class Extractor():
def _launch_plugin(self, phase):
""" Launch an extraction plugin """
logger.debug("Launching %s plugin", phase)
in_qname = "extract{}_{}_in".format(self._instance, phase)
in_qname = f"extract{self._instance}_{phase}_in"
if phase == self._final_phase:
out_qname = "extract{}_{}_out".format(self._instance, self._final_phase)
out_qname = f"extract{self._instance}_{self._final_phase}_out"
else:
next_phase = self._flow[self._flow.index(phase) + 1]
out_qname = "extract{}_{}_in".format(self._instance, next_phase)
out_qname = f"extract{self._instance}_{next_phase}_in"
logger.debug("in_qname: %s, out_qname: %s", in_qname, out_qname)
kwargs = dict(in_queue=self._queues[in_qname], out_queue=self._queues[out_qname])
plugin_type, idx = self._get_plugin_type_and_index(phase)
plugin = getattr(self, "_{}".format(plugin_type))
plugin = getattr(self, f"_{plugin_type}")
plugin = plugin[idx] if idx is not None else plugin
plugin.initialize(**kwargs)
plugin.start()
@ -645,7 +644,7 @@ class Extractor():
logger.debug("Remaining VRAM to allocate: %sMB", remaining)
if batchsizes != requested_batchsizes:
text = ", ".join(["{}: {}".format(plugin.__class__.__name__, batchsize)
text = ", ".join([f"{plugin.__class__.__name__}: {batchsize}"
for plugin, batchsize in zip(plugins, batchsizes)])
for plugin, batchsize in zip(plugins, batchsizes):
plugin.batchsize = batchsize
@ -726,7 +725,7 @@ class ExtractMedia():
A copy of :attr:`image` in the requested :attr:`color_format`
"""
logger.trace("Requested color format '%s' for frame '%s'", color_format, self._filename)
image = getattr(self, "_image_as_{}".format(color_format.lower()))()
image = getattr(self, f"_image_as_{color_format.lower()}")()
return image
def add_detected_faces(self, faces):

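The f-string conversions above do not change behaviour, so the queue naming scheme is unchanged; as a worked example with illustrative values, an extractor with instance 0, flow ['detect', 'align', 'mask_0'] and final phase 'mask_0' creates the input queues extract0_detect_in, extract0_align_in and extract0_mask_0_in plus the output queue extract0_mask_0_out, and each phase's output queue is simply the next phase's input queue until the final pass.
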
View File

@ -21,6 +21,13 @@ _LOSS_HELP = dict(
"The L_inf norm will reduce the largest individual pixel error in an image. As "
"each largest error is minimized sequentially, the overall error is improved. This loss "
"will be extremely focused on outliers."),
laploss=(
"Laplacian Pyramid Loss. Attempts to improve results by focussing on edges using "
"Laplacian Pyramids. As this loss function gives priority to edges over other low-"
"frequency information, like color, it should not be used on its own. The original "
"implementation uses this loss as a complimentary function to MSE. "
"Ref: Optimizing the Latent Space of Generative Networks "
"https://arxiv.org/abs/1707.05776"),
logcosh=(
"log(cosh(x)) acts similar to MSE for small errors and to MAE for large errors. Like "
"MSE, it is very stable and prevents overshoots when errors are near zero. Like MAE, it "

View File

@ -53,15 +53,16 @@ class Loss():
def __init__(self, config: dict) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._config = config
self._loss_dict = dict(mae=k_losses.mean_absolute_error,
mse=k_losses.mean_squared_error,
logcosh=k_losses.logcosh,
smooth_loss=losses.GeneralizedLoss(),
self._loss_dict = dict(gmsd=losses.GMSDLoss(),
l_inf_norm=losses.LInfNorm(),
ssim=losses.DSSIMObjective(),
laploss=losses.LaplacianPyramidLoss(),
logcosh=k_losses.logcosh,
ms_ssim=losses.MSSIMLoss(),
gmsd=losses.GMSDLoss(),
pixel_gradient_diff=losses.GradientLoss())
mae=k_losses.mean_absolute_error,
mse=k_losses.mean_squared_error,
pixel_gradient_diff=losses.GradientLoss(),
ssim=losses.DSSIMObjective(),
smooth_loss=losses.GeneralizedLoss(),)
self._mask_channels = self._get_mask_channels()
self._inputs: List[keras.layers.Layer] = []
self._names: List[str] = []

View File

@ -10,5 +10,15 @@ exclude = .git, __pycache__
ignore_missing_imports = True
[mypy-keras.*]
ignore_missing_imports = True
[mypy-psutil.*]
ignore_missing_imports = True
[mypy-plaidml.*]
ignore_missing_imports = True
[mypy-tqdm.*]
ignore_missing_imports = True
[mypy-pynvml.*]
ignore_missing_imports = True
[mypy-pynvx.*]
ignore_missing_imports = True
[mypy-cv2.*]
ignore_missing_imports = True

View File

@ -41,11 +41,18 @@ def test_loss_output(loss_func, output_shape):
assert output.dtype == "float32" and not np.any(np.isnan(output))
_LWPARAMS = [losses.GeneralizedLoss(), losses.GradientLoss(), losses.GMSDLoss(),
losses.LInfNorm(), k_losses.mean_absolute_error, k_losses.mean_squared_error,
k_losses.logcosh, losses.DSSIMObjective(), losses.MSSIMLoss()]
_LWIDS = ["GeneralizedLoss", "GradientLoss", "GMSDLoss", "LInfNorm", "mae", "mse", "logcosh",
"DSSIMObjective", "MS-SSIM"]
_LWPARAMS = [losses.DSSIMObjective(),
losses.GeneralizedLoss(),
losses.GMSDLoss(),
losses.GradientLoss(),
losses.LaplacianPyramidLoss(),
losses.LInfNorm(),
k_losses.logcosh,
k_losses.mean_absolute_error,
k_losses.mean_squared_error,
losses.MSSIMLoss()]
_LWIDS = ["DSSIMObjective", "GeneralizedLoss", "GMSDLoss", "GradientLoss", "LaplacianPyramidLoss",
"LInfNorm", "logcosh", "mae", "mse", "MS-SSIM"]
_LWIDS = [f"{loss}[{get_backend().upper()}]" for loss in _LWIDS]
@ -57,8 +64,8 @@ def test_loss_wrapper(loss_func):
pytest.skip("GMSD Loss is not currently compatible with PlaidML")
if hasattr(loss_func, "__name__") and loss_func.__name__ == "logcosh":
pytest.skip("LogCosh Loss is not currently compatible with PlaidML")
y_a = K.variable(np.random.random((2, 16, 16, 4)))
y_b = K.variable(np.random.random((2, 16, 16, 3)))
y_a = K.variable(np.random.random((2, 64, 64, 4)))
y_b = K.variable(np.random.random((2, 64, 64, 3)))
p_loss = losses.LossWrapper()
p_loss.add_loss(loss_func, 1.0, -1)
p_loss.add_loss(k_losses.mean_squared_error, 2.0, 3)

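The move from 16x16 to 64x64 test images is consistent with the new loss (an editorial note): LaplacianPyramidLoss with the default max_levels=5 halves the image five times, so 64x64 shrinks to 2x2 at the final level, whereas a 16x16 input would drop below a single pixel before the pyramid completes. The four-channel y_a against the three-channel y_b also exercises the mask path, since the wrapper adds mean_squared_error with mask channel 3.
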
View File

@ -14,6 +14,7 @@ from urllib.request import urlretrieve
import os
from os.path import join as pathjoin, expanduser
_TRAIN_ARGS = (1, 1) if os.environ.get("FACESWAP_BACKEND", "cpu").lower() == "amd" else (4, 4)
FAIL_COUNT = 0
TEST_COUNT = 0
_COLORS = {
@ -167,18 +168,21 @@ def main():
run_test(
"Train lightweight model for 1 iteration with WTL.",
train_args(
"lightweight", pathjoin(vid_base, "model"),
pathjoin(vid_base, "faces"), extra_args="-wl"
)
)
train_args("lightweight",
pathjoin(vid_base, "model"),
pathjoin(vid_base, "faces"),
iterations=_TRAIN_ARGS[0],
batchsize=_TRAIN_ARGS[1],
extra_args="-wl"))
was_trained = run_test(
"Train lightweight model for 1 iterations WITHOUT WTL.",
train_args(
"lightweight", pathjoin(vid_base, "model"), pathjoin(vid_base, "faces")
)
)
train_args("lightweight",
pathjoin(vid_base, "model"),
pathjoin(vid_base, "faces"),
iterations=_TRAIN_ARGS[0],
batchsize=_TRAIN_ARGS[1],
extra_args="-wl"))
if was_trained:
run_test(