Mirror of https://github.com/zebrajr/faceswap.git
Add Laplacian Pyramid Loss
This commit is contained in: parent 04337e0c5e, commit d9c84a5f9f
@@ -11,12 +11,12 @@ from ._base import set_exclude_devices # noqa
backend = get_backend()
if backend == "nvidia" and platform.system().lower() == "darwin":
from .nvidia_apple import NvidiaAppleStats as GPUStats # noqa
from .nvidia_apple import NvidiaAppleStats as GPUStats # type:ignore # noqa
elif backend == "nvidia":
from .nvidia import NvidiaStats as GPUStats # noqa
from .nvidia import NvidiaStats as GPUStats # type:ignore # noqa
elif backend == "amd":
from .amd import AMDStats as GPUStats, setup_plaidml # noqa
from .amd import AMDStats as GPUStats, setup_plaidml # type:ignore # noqa
elif backend == "apple_silicon":
from .apple_silicon import AppleSiliconStats as GPUStats # noqa
from .apple_silicon import AppleSiliconStats as GPUStats # type:ignore # noqa
elif backend == "cpu":
from .cpu import CPUStats as GPUStats # noqa
from .cpu import CPUStats as GPUStats # type:ignore # noqa

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
""" Parent class for obtaining Stats for various GPU/TPU backends. All GPU Stats should inherit
from the :class:`GPUStats` class contained here. """
from the :class:`_GPUStats` class contained here. """
import logging
import os

@@ -23,7 +23,7 @@ class GPUInfo(TypedDict):
vram: List[int]
driver: str
devices: List[str]
devices_active: List[str]
devices_active: List[int]
class BiggestGPUInfo(TypedDict):

@@ -50,7 +50,7 @@ def set_exclude_devices(devices: List[int]) -> None:
_EXCLUDE_DEVICES.extend(devices)
class GPUStats():
class _GPUStats():
""" Parent class for returning information of GPUs used. """
def __init__(self, log: bool = True) -> None:

@@ -67,8 +67,8 @@ class GPUStats():
self._handles: list = self._get_handles()
self._driver: str = self._get_driver()
self._device_names: List[str] = self._get_device_names()
self._vram: List[float] = self._get_vram()
self._vram_free: List[float] = self._get_free_vram()
self._vram: List[int] = self._get_vram()
self._vram_free: List[int] = self._get_free_vram()
if get_backend() != "cpu" and not self._active_devices:
self._log("warning", "No GPU detected")

@@ -164,8 +164,8 @@ class GPUStats():
devices = [idx for idx in range(self._device_count) if idx not in _EXCLUDE_DEVICES]
env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if env_devices:
env_devices = [int(i) for i in env_devices.split(",")]
devices = [idx for idx in devices if idx in env_devices]
new_devices = [int(i) for i in env_devices.split(",")]
devices = [idx for idx in devices if idx in new_devices]
self._log("debug", f"Active GPU Devices: {devices}")
return devices
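The renamed `new_devices` variable above keeps the raw environment string and the parsed list separate. A stand-alone sketch of the same filtering logic (a hypothetical helper for illustration, not part of the commit):

import os

def filter_active_devices(device_count, exclude=()):
    """ Mirror the filtering above: drop excluded indices, then honour CUDA_VISIBLE_DEVICES. """
    devices = [idx for idx in range(device_count) if idx not in exclude]
    env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if env_devices:
        new_devices = [int(i) for i in env_devices.split(",")]
        devices = [idx for idx in devices if idx in new_devices]
    return devices

os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
print(filter_active_devices(4, exclude=(1,)))  # [0, 2]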
@@ -202,7 +202,7 @@ class GPUStats():
"""
raise NotImplementedError()
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Override to obtain the total VRAM in Megabytes for each connected GPU.
Returns

@@ -213,8 +213,8 @@ class GPUStats():
"""
raise NotImplementedError()
def _get_free_vram(self) -> List[float]:
""" Override to obrain the amount of VRAM that is available, in Megabytes, for each
def _get_free_vram(self) -> List[int]:
""" Override to obtain the amount of VRAM that is available, in Megabytes, for each
connected GPU.
Returns

@@ -9,7 +9,7 @@ from typing import List, Optional
import plaidml
from ._base import GPUStats, _EXCLUDE_DEVICES
from ._base import _GPUStats, _EXCLUDE_DEVICES
_PLAIDML_INITIALIZED: bool = False

@@ -40,7 +40,7 @@ def setup_plaidml(log_level: str, exclude_devices: List[int]) -> None:
logger.info("Successfully set up for PlaidML")
class AMDStats(GPUStats):
class AMDStats(_GPUStats):
""" Holds information and statistics about AMD GPU(s) available on the currently
running system.

@@ -104,9 +104,9 @@ class AMDStats(GPUStats):
return retval
@property
def _all_vram(self) -> List[float]:
def _all_vram(self) -> List[int]:
""" list: The VRAM of each GPU device that PlaidML has discovered. """
return [int(device.get("globalMemSize", 0)) / (1024 * 1024)
return [int(device.get("globalMemSize", 0) / (1024 * 1024))
for device in self._device_details]
@property
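The `_all_vram` change above only moves a parenthesis, but it changes the element type: dividing after the `int()` cast produced floats, converting after the division produces whole megabytes, matching the new `List[int]` annotations. A tiny illustration with a made-up device value (not from the commit):

global_mem = 8589934592  # hypothetical "globalMemSize" of 8 GiB, in bytes

old_style = int(global_mem) / (1024 * 1024)   # 8192.0 -> float element
new_style = int(global_mem / (1024 * 1024))   # 8192   -> int element
print(old_style, new_style)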
@@ -159,7 +159,7 @@ class AMDStats(GPUStats):
self._log("debug", "Setting PlaidML Default Logger")
plaidml.DEFAULT_LOG_HANDLER = logging.getLogger("plaidml_root")
plaidml.DEFAULT_LOG_HANDLER.propagate = 0
plaidml.DEFAULT_LOG_HANDLER.propagate = False
numeric_level = getattr(logging, self._log_level, None)
if numeric_level < 10: # DEBUG Logging

@@ -344,8 +344,8 @@ class AMDStats(GPUStats):
str
The current AMD GPU driver versions
"""
drivers = [device.get("driverVersion", "No Driver Found")
for device in self._device_details]
drivers = "|".join([device.get("driverVersion", "No Driver Found")
for device in self._device_details])
self._log("debug", f"GPU Drivers: {drivers}")
return drivers

@@ -361,7 +361,7 @@ class AMDStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each connected AMD GPU as identified in
:attr:`_handles`.

@@ -374,7 +374,7 @@ class AMDStats(GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD
GPU.
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """
from typing import List, Optional
from typing import Any, List
import os
import psutil

@@ -8,13 +8,13 @@ import tensorflow as tf
from lib.utils import FaceswapError
from ._base import GPUStats
from ._base import _GPUStats
_METAL_INITIALIZED: bool = False
class AppleSiliconStats(GPUStats):
class AppleSiliconStats(_GPUStats):
""" Holds information and statistics about Apple Silicon SoC(s) available on the currently
running Apple system.

@@ -35,7 +35,7 @@ class AppleSiliconStats(GPUStats):
"""
def __init__(self, log: bool = True) -> None:
# Following attribute set in :func:``_initialize``
self._tf_devices: Optional(List[str]) = None
self._tf_devices: List[Any] = []
super().__init__(log=log)

@@ -155,7 +155,7 @@ class AppleSiliconStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in
:attr:`_handles`.

@@ -170,12 +170,12 @@ class AppleSiliconStats(GPUStats):
list
The RAM in Megabytes for each available Apple Silicon SoC
"""
vram = [(psutil.virtual_memory().total / self._device_count) / (1024 * 1024)
vram = [int((psutil.virtual_memory().total / self._device_count) / (1024 * 1024))
for _ in range(self._device_count)]
self._log("debug", f"SoC RAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each available Apple
Silicon SoC.

@@ -185,7 +185,7 @@ class AppleSiliconStats(GPUStats):
List of `float`s containing the amount of RAM available, in Megabytes, for each
available SoC as corresponding to the values in :attr:`_handles
"""
vram = [(psutil.virtual_memory().available / self._device_count) / (1024 * 1024)
vram = [int((psutil.virtual_memory().available / self._device_count) / (1024 * 1024))
for _ in range(self._device_count)]
self._log("debug", f"SoC RAM free: {vram}")
return vram
@@ -4,19 +4,19 @@
from typing import List
from ._base import GPUStats
from ._base import _GPUStats
class CPUStats(GPUStats):
class CPUStats(_GPUStats):
""" Holds information and statistics about the CPU on the currently running system.
Notes
-----
The information held here is not useful, but GPUStats is dynamically imported depending on the
backend used, so we need to make sure this class is available for Faceswap run on the CPU
The information held here is not useful, but _GPUStats is dynamically imported depending on
the backend used, so we need to make sure this class is available for Faceswap run on the CPU
Backend.
The base :class:`GPUStats` handles the dummying in of information when no GPU is detected.
The base :class:`_GPUStats` handles the dummying in of information when no GPU is detected.
Parameters
----------

@@ -49,7 +49,7 @@ class CPUStats(GPUStats):
list
An empty list for CPU Backends
"""
handles = []
handles: list = []
self._log("debug", f"GPU Handles found: {len(handles)}")
return handles

@@ -73,11 +73,11 @@ class CPUStats(GPUStats):
list
An empty list for CPU backends
"""
names = []
names: List[str] = []
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the RAM in Megabytes for the running system.
Returns

@@ -85,11 +85,11 @@ class CPUStats(GPUStats):
list
An empty list for CPU backends
"""
vram = []
vram: List[int] = []
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of RAM that is available, in Megabytes, for the running system.
Returns

@@ -97,6 +97,6 @@ class CPUStats(GPUStats):
list
An empty list for CPU backends
"""
vram = []
vram: List[int] = []
self._log("debug", f"GPU VRAM free: {vram}")
return vram
@@ -7,10 +7,10 @@ import pynvml
from lib.utils import FaceswapError
from ._base import GPUStats
from ._base import _GPUStats
class NvidiaStats(GPUStats):
class NvidiaStats(_GPUStats):
""" Holds information and statistics about Nvidia GPU(s) available on the currently
running system.

@@ -125,7 +125,7 @@ class NvidiaStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.

@@ -139,7 +139,7 @@ class NvidiaStats(GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.
@@ -6,10 +6,10 @@ import pynvx
from lib.utils import FaceswapError
from ._base import GPUStats
from ._base import _GPUStats
class NvidiaAppleStats(GPUStats):
class NvidiaAppleStats(_GPUStats):
""" Holds information and statistics about Nvidia GPU(s) available on the currently
running Apple system.

@@ -105,7 +105,7 @@ class NvidiaAppleStats(GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[float]:
def _get_vram(self) -> List[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.

@@ -120,7 +120,7 @@ class NvidiaAppleStats(GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[float]:
def _get_free_vram(self) -> List[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.
@@ -4,7 +4,7 @@
from __future__ import absolute_import
import logging
from typing import List, Tuple
from typing import Callable, List, Tuple
import numpy as np
import plaidml

@@ -621,18 +621,155 @@ class GMSDLoss(): # pylint:disable=too-few-public-methods
return output
class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods
""" Laplacian Pyramid Loss Function
Notes
-----
Channels last implementation on square images only.
Parameters
----------
max_levels: int, Optional
The max number of laplacian pyramid levels to use. Default: `5`
gaussian_size: int, Optional
The size of the gaussian kernel. Default: `5`
gaussian_sigma: float, optional
The gaussian sigma. Default: 2.0
References
----------
https://arxiv.org/abs/1707.05776
https://github.com/nathanaelbosch/generative-latent-optimization/blob/master/utils.py
"""
def __init__(self,
max_levels: int = 5,
gaussian_size: int = 5,
gaussian_sigma: float = 1.0) -> None:
self._max_levels = max_levels
self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)
self._shape: Tuple[int, ...] = ()
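The per-level weights above decay by a factor of four per pyramid level. A quick stand-alone check (an illustration, not part of the commit) of the values handed to `K.constant` for the default `max_levels=5`:

import numpy as np

weights = [np.power(2., -2 * idx) for idx in range(5 + 1)]
print(weights)  # [1.0, 0.25, 0.0625, 0.015625, 0.00390625, 0.0009765625]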
@classmethod
def _get_gaussian_kernel(cls, size: int, sigma: float) -> plaidml.tile.Value:
""" Obtain the base gaussian kernel for the Laplacian Pyramid.
Parameters
----------
size: int, Optional
The size of the gaussian kernel
sigma: float
The gaussian sigma
Returns
-------
:class:`plaidml.tile.Value`
The base single channel Gaussian kernel
"""
assert size % 2 == 1, ("kernel size must be uneven")
x_1 = np.linspace(- (size // 2), size // 2, size, dtype="float32")
x_1 /= np.sqrt(2)*sigma
x_2 = x_1 ** 2
kernel = np.exp(- x_2[:, None] - x_2[None, :])
kernel /= kernel.sum()
kernel = np.reshape(kernel, (size, size, 1, 1))
return K.constant(kernel)
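The kernel construction above is pure NumPy until the final `K.constant`. A minimal stand-alone sketch (for inspection only, not part of the commit) showing that the result is a normalised Gaussian with its peak at the centre:

import numpy as np

def gaussian_kernel(size=5, sigma=1.0):
    # Same construction as _get_gaussian_kernel above, kept in NumPy for inspection
    x_1 = np.linspace(-(size // 2), size // 2, size, dtype="float32")
    x_1 /= np.sqrt(2) * sigma
    x_2 = x_1 ** 2
    kernel = np.exp(-x_2[:, None] - x_2[None, :])
    return kernel / kernel.sum()

kernel = gaussian_kernel()
print(kernel.shape)                     # (5, 5)
print(round(float(kernel.sum()), 6))    # 1.0
print(kernel[2, 2] == kernel.max())     # centre holds the largest weight -> True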
def _conv_gaussian(self, inputs: plaidml.tile.Value) -> plaidml.tile.Value:
""" Perform Gaussian convolution on a batch of images.
Parameters
----------
inputs: :class:`plaidml.tile.Value`
The input batch of images to perform Gaussian convolution on.
Returns
-------
:class:`plaidml.tile.Value`
The convolved images
"""
channels = self._shape[-1]
gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))
# PlaidML doesn't implement replication padding like pytorch. This is an inefficient way to
# implement it for a square guassian kernel
size = K.int_shape(self._gaussian_kernel)[1] // 2
padded_inputs = inputs
for _ in range(size):
padded_inputs = pad(padded_inputs, # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
([0, 0], [1, 1], [1, 1], [0, 0]),
mode="REFLECT")
retval = K.conv2d(padded_inputs, gauss, strides=(1, 1), padding="valid")
return retval
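The loop above grows the border by one pixel per iteration so that the subsequent "valid" convolution returns the original spatial size. A small NumPy illustration of the same idea (assumed shapes, not from the commit):

import numpy as np

image = np.arange(16, dtype="float32").reshape(1, 4, 4, 1)  # NHWC batch of one 4x4 image
kernel_size = 5
pad_steps = kernel_size // 2   # two single-pixel reflect pads, as in the loop above

padded = image
for _ in range(pad_steps):
    padded = np.pad(padded, ((0, 0), (1, 1), (1, 1), (0, 0)), mode="reflect")

print(padded.shape)                          # (1, 8, 8, 1)
print(padded.shape[1] - (kernel_size - 1))   # 4: a "valid" 5x5 convolution restores 4x4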
def _get_laplacian_pyramid(self, inputs: plaidml.tile.Value) -> List[plaidml.tile.Value]:
""" Obtain the Laplacian Pyramid.
Parameters
----------
inputs: :class:`plaidml.tile.Value`
The input batch of images to run through the Laplacian Pyramid
Returns
-------
list
The tensors produced from the Laplacian Pyramid
"""
pyramid = []
current = inputs
for _ in range(self._max_levels):
gauss = self._conv_gaussian(current)
diff = current - gauss
pyramid.append(diff)
current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
pyramid.append(current)
return pyramid
def __call__(self,
y_true: plaidml.tile.Value,
y_pred: plaidml.tile.Value) -> plaidml.tile.Value:
""" Calculate the Laplacian Pyramid Loss.
Parameters
----------
y_true: :class:`plaidml.tile.Value`
The ground truth value
y_pred: :class:`plaidml.tile.Value`
The predicted value
Returns
-------
:class: `plaidml.tile.Value`
The loss value
"""
if not self._shape:
self._shape = K.int_shape(y_pred)
pyramid_true = self._get_laplacian_pyramid(y_true)
pyramid_pred = self._get_laplacian_pyramid(y_pred)
losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
loss = K.sum(losses * self._weights)
return loss
class LossWrapper(): # pylint:disable=too-few-public-methods
""" A wrapper class for multiple keras losses to enable multiple weighted loss functions on a
single output and masking.
"""
def __init__(self):
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
self._loss_functions = []
self._loss_weights = []
self._mask_channels = []
self._loss_functions: List[Callable] = []
self._loss_weights: List[float] = []
self._mask_channels: List[int] = []
logger.debug("Initialized: %s", self.__class__.__name__)
def add_loss(self, function, weight=1.0, mask_channel=-1):
def add_loss(self,
function,
weight: float = 1.0,
mask_channel: int = -1) -> None:
""" Add the given loss function with the given weight to the loss function chain.
Parameters
@@ -651,21 +788,23 @@ class LossWrapper(): # pylint:disable=too-few-public-methods
self._loss_weights.append(weight)
self._mask_channels.append(mask_channel)
def __call__(self, y_true, y_pred):
def __call__(self,
y_true: plaidml.tile.Value,
y_pred: plaidml.tile.Value) -> plaidml.tile.Value:
""" Call the sub loss functions for the loss wrapper.
Weights are returned as the weighted sum of the chosen losses.
Parameters
----------
y_true: tensor or variable
y_true: :class:`plaidml.tile.Value`
The ground truth value
y_pred: tensor or variable
y_pred: :class:`plaidml.tile.Value`
The predicted value
Returns
-------
tensor
:class:`plaidml.tile.Value`
The final loss value
"""
loss = 0.0

@@ -685,15 +824,19 @@ class LossWrapper(): # pylint:disable=too-few-public-methods
return loss
@classmethod
def _apply_mask(cls, y_true, y_pred, mask_channel, mask_prop=1.0):
def _apply_mask(cls,
y_true: plaidml.tile.Value,
y_pred: plaidml.tile.Value,
mask_channel: int,
mask_prop: float = 1.0) -> Tuple[plaidml.tile.Value, plaidml.tile.Value]:
""" Apply the mask to the input y_true and y_pred. If a mask is not required then
return the unmasked inputs.
Parameters
----------
y_true: tensor or variable
y_true: :class:`plaidml.tile.Value`
The ground truth value
y_pred: tensor or variable
y_pred: :class:`plaidml.tile.Value`
The predicted value
mask_channel: int
The channel within y_true that the required mask resides in

@@ -4,7 +4,7 @@
from __future__ import absolute_import
import logging
from typing import List, Tuple
from typing import Callable, List, Tuple
import numpy as np
import tensorflow as tf
@@ -558,6 +558,136 @@ class GMSDLoss(tf.keras.losses.Loss): # pylint:disable=too-few-public-methods
return output
class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods
""" Laplacian Pyramid Loss Function
Notes
-----
Channels last implementation on square images only.
Parameters
----------
max_levels: int, Optional
The max number of laplacian pyramid levels to use. Default: `5`
gaussian_size: int, Optional
The size of the gaussian kernel. Default: `5`
gaussian_sigma: float, optional
The gaussian sigma. Default: 2.0
References
----------
https://arxiv.org/abs/1707.05776
https://github.com/nathanaelbosch/generative-latent-optimization/blob/master/utils.py
"""
def __init__(self,
max_levels: int = 5,
gaussian_size: int = 5,
gaussian_sigma: float = 1.0) -> None:
self._max_levels = max_levels
self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)
@classmethod
def _get_gaussian_kernel(cls, size: int, sigma: float) -> tf.Tensor:
""" Obtain the base gaussian kernel for the Laplacian Pyramid.
Parameters
----------
size: int, Optional
The size of the gaussian kernel
sigma: float
The gaussian sigma
Returns
-------
tf.Tensor
The base single channel Gaussian kernel
"""
assert size % 2 == 1, ("kernel size must be uneven")
x_1 = np.linspace(- (size // 2), size // 2, size, dtype="float32")
x_1 /= np.sqrt(2)*sigma
x_2 = x_1 ** 2
kernel = np.exp(- x_2[:, None] - x_2[None, :])
kernel /= kernel.sum()
kernel = np.reshape(kernel, (size, size, 1, 1))
return K.constant(kernel)
def _conv_gaussian(self, inputs: tf.Tensor) -> tf.Tensor:
""" Perform Gaussian convolution on a batch of images.
Parameters
----------
inputs: :class:`tf.Tensor`
The input batch of images to perform Gaussian convolution on.
Returns
-------
:class:`tf.Tensor`
The convolved images
"""
channels = K.int_shape(inputs)[-1]
gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))
# TF doesn't implement replication padding like pytorch. This is an inefficient way to
# implement it for a square guassian kernel
size = self._gaussian_kernel.shape[1] // 2
padded_inputs = inputs
for _ in range(size):
padded_inputs = tf.pad(padded_inputs, # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
([0, 0], [1, 1], [1, 1], [0, 0]),
mode="SYMMETRIC")
retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
return retval
def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> List[tf.Tensor]:
""" Obtain the Laplacian Pyramid.
Parameters
----------
inputs: :class:`tf.Tensor`
The input batch of images to run through the Laplacian Pyramid
Returns
-------
list
The tensors produced from the Laplacian Pyramid
"""
pyramid = []
current = inputs
for _ in range(self._max_levels):
gauss = self._conv_gaussian(current)
diff = current - gauss
pyramid.append(diff)
current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
pyramid.append(current)
return pyramid
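For orientation (an inference from the code above, not stated in the commit): each level stores the difference between an image and its Gaussian-blurred copy, then halves the resolution, so a 64x64 input with the default five levels yields residuals at 64, 32, 16, 8 and 4 pixels plus a final 2x2 low-pass image. A quick trace:

size = 64
sizes = []
for _ in range(5):
    sizes.append(size)   # residual (image minus blurred image) kept at this size
    size //= 2           # 2x2 average pooling halves the resolution
sizes.append(size)       # final low-pass image
print(sizes)             # [64, 32, 16, 8, 4, 2]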
def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
""" Calculate the Laplacian Pyramid Loss.
Parameters
----------
y_true: :class:`tf.Tensor`
The ground truth value
y_pred: :class:`tf.Tensor`
The predicted value
Returns
-------
:class: `tf.Tensor`
The loss value
"""
pyramid_true = self._get_laplacian_pyramid(y_true)
pyramid_pred = self._get_laplacian_pyramid(y_pred)
losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
loss = K.sum(losses * self._weights)
return loss
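A minimal smoke test of the class above (a sketch only, assuming TensorFlow is installed and that the file is importable as `lib.model.losses_tf`; this page does not name the file):

import numpy as np
import tensorflow as tf

from lib.model.losses_tf import LaplacianPyramidLoss  # assumed module path

loss_fn = LaplacianPyramidLoss(max_levels=5, gaussian_size=5, gaussian_sigma=1.0)
y_true = tf.constant(np.random.random((2, 64, 64, 3)), dtype="float32")
y_pred = tf.constant(np.random.random((2, 64, 64, 3)), dtype="float32")

print(float(loss_fn(y_true, y_pred)))   # small positive scalar
print(float(loss_fn(y_true, y_true)))   # 0.0 for identical inputs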
class LossWrapper():
""" A wrapper class for multiple keras losses to enable multiple masked weighted loss
functions on a single output.

@@ -586,7 +716,7 @@ class LossWrapper():
logger.debug("Initialized: %s", self.__class__.__name__)
def add_loss(self,
function: tf.keras.losses.Loss,
function: Callable,
weight: float = 1.0,
mask_channel: int = -1) -> None:
""" Add the given loss function with the given weight to the loss function chain.
@@ -122,7 +122,7 @@ class Extractor():
For align/mask (2nd/3rd pass operations) the :attr:`ExtractMedia.detected_faces` should
also be populated by calling :func:`ExtractMedia.set_detected_faces`.
"""
qname = "extract{}_{}_in".format(self._instance, self._current_phase[0])
qname = f"extract{self._instance}_{self._current_phase[0]}_in"
retval = self._queues[qname]
logger.trace("%s: %s", qname, retval)
return retval

@@ -193,7 +193,7 @@ class Extractor():
The batch size to use for this plugin type
"""
logger.debug("Overriding batchsize for plugin_type: %s to: %s", plugin_type, batchsize)
plugin = getattr(self, "_{}".format(plugin_type))
plugin = getattr(self, f"_{plugin_type}")
plugin.batchsize = batchsize
def launch(self):

@@ -283,10 +283,10 @@ class Extractor():
@property
def _vram_per_phase(self):
""" dict: The amount of vram required for each phase in :attr:`_flow`. """
retval = dict()
retval = {}
for phase in self._flow:
plugin_type, idx = self._get_plugin_type_and_index(phase)
attr = getattr(self, "_{}".format(plugin_type))
attr = getattr(self, f"_{plugin_type}")
attr = attr[idx] if idx is not None else attr
retval[phase] = attr.vram
logger.trace(retval)

@@ -322,10 +322,9 @@ class Extractor():
def _output_queue(self):
""" Return the correct output queue depending on the current phase """
if self.final_pass:
qname = "extract{}_{}_out".format(self._instance, self._final_phase)
qname = f"extract{self._instance}_{self._final_phase}_out"
else:
qname = "extract{}_{}_in".format(self._instance,
self._phases[self._phase_index + 1][0])
qname = f"extract{self._instance}_{self._phases[self._phase_index + 1][0]}_in"
retval = self._queues[qname]
logger.trace("%s: %s", qname, retval)
return retval

@@ -336,7 +335,7 @@ class Extractor():
retval = []
for phase in self._flow:
plugin_type, idx = self._get_plugin_type_and_index(phase)
attr = getattr(self, "_{}".format(plugin_type))
attr = getattr(self, f"_{plugin_type}")
attr = attr[idx] if idx is not None else attr
retval.append(attr)
logger.trace("All Plugins: %s", retval)

@@ -348,7 +347,7 @@ class Extractor():
retval = []
for phase in self._current_phase:
plugin_type, idx = self._get_plugin_type_and_index(phase)
attr = getattr(self, "_{}".format(plugin_type))
attr = getattr(self, f"_{plugin_type}")
retval.append(attr[idx] if idx is not None else attr)
logger.trace("Active plugins: %s", retval)
return retval

@@ -362,7 +361,7 @@ class Extractor():
retval.append("detect")
if aligner is not None and aligner.lower() != "none":
retval.append("align")
retval.extend(["mask_{}".format(idx)
retval.extend([f"mask_{idx}"
for idx, mask in enumerate(masker)
if mask is not None and mask.lower() != "none"])
logger.debug("flow: %s", retval)

@@ -400,9 +399,9 @@ class Extractor():
def _add_queues(self):
""" Add the required processing queues to Queue Manager """
queues = dict()
tasks = ["extract{}_{}_in".format(self._instance, phase) for phase in self._flow]
tasks.append("extract{}_{}_out".format(self._instance, self._final_phase))
queues = {}
tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow]
tasks.append(f"extract{self._instance}_{self._final_phase}_out")
for task in tasks:
# Limit queue size to avoid stacking ram
queue_manager.add_queue(task, maxsize=self._queue_size)

@@ -552,17 +551,17 @@ class Extractor():
def _launch_plugin(self, phase):
""" Launch an extraction plugin """
logger.debug("Launching %s plugin", phase)
in_qname = "extract{}_{}_in".format(self._instance, phase)
in_qname = f"extract{self._instance}_{phase}_in"
if phase == self._final_phase:
out_qname = "extract{}_{}_out".format(self._instance, self._final_phase)
out_qname = f"extract{self._instance}_{self._final_phase}_out"
else:
next_phase = self._flow[self._flow.index(phase) + 1]
out_qname = "extract{}_{}_in".format(self._instance, next_phase)
out_qname = f"extract{self._instance}_{next_phase}_in"
logger.debug("in_qname: %s, out_qname: %s", in_qname, out_qname)
kwargs = dict(in_queue=self._queues[in_qname], out_queue=self._queues[out_qname])
plugin_type, idx = self._get_plugin_type_and_index(phase)
plugin = getattr(self, "_{}".format(plugin_type))
plugin = getattr(self, f"_{plugin_type}")
plugin = plugin[idx] if idx is not None else plugin
plugin.initialize(**kwargs)
plugin.start()

@@ -645,7 +644,7 @@ class Extractor():
logger.debug("Remaining VRAM to allocate: %sMB", remaining)
if batchsizes != requested_batchsizes:
text = ", ".join(["{}: {}".format(plugin.__class__.__name__, batchsize)
text = ", ".join([f"{plugin.__class__.__name__}: {batchsize}"
for plugin, batchsize in zip(plugins, batchsizes)])
for plugin, batchsize in zip(plugins, batchsizes):
plugin.batchsize = batchsize

@@ -726,7 +725,7 @@ class ExtractMedia():
A copy of :attr:`image` in the requested :attr:`color_format`
"""
logger.trace("Requested color format '%s' for frame '%s'", color_format, self._filename)
image = getattr(self, "_image_as_{}".format(color_format.lower()))()
image = getattr(self, f"_image_as_{color_format.lower()}")()
return image
def add_detected_faces(self, faces):
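The pipeline changes above are mechanical conversions from `str.format` to f-strings. A quick check (illustrative values only, not from the commit) that the two forms produce identical queue names:

instance, phase = 0, "align"
old_style = "extract{}_{}_in".format(instance, phase)
new_style = f"extract{instance}_{phase}_in"
assert old_style == new_style == "extract0_align_in"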
@@ -21,6 +21,13 @@ _LOSS_HELP = dict(
"The L_inf norm will reduce the largest individual pixel error in an image. As "
"each largest error is minimized sequentially, the overall error is improved. This loss "
"will be extremely focused on outliers."),
laploss=(
"Laplacian Pyramid Loss. Attempts to improve results by focussing on edges using "
"Laplacian Pyramids. As this loss function gives priority to edges over other low-"
"frequency information, like color, it should not be used on its own. The original "
"implementation uses this loss as a complimentary function to MSE. "
"Ref: Optimizing the Latent Space of Generative Networks "
"https://arxiv.org/abs/1707.05776"),
logcosh=(
"log(cosh(x)) acts similar to MSE for small errors and to MAE for large errors. Like "
"MSE, it is very stable and prevents overshoots when errors are near zero. Like MAE, it "
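As the help text notes, `laploss` is meant to complement a low-frequency loss rather than replace it. A hedged sketch of pairing it with MSE through the wrapper (assuming the loss classes live in `lib.model.losses_tf`, a module path this page does not name):

from tensorflow.keras import losses as k_losses

from lib.model.losses_tf import LaplacianPyramidLoss, LossWrapper  # assumed module path

loss = LossWrapper()
loss.add_loss(k_losses.mean_squared_error, weight=1.0)  # primary low-frequency term
loss.add_loss(LaplacianPyramidLoss(), weight=1.0)       # complementary edge-focused term
# loss(y_true, y_pred) now returns the weighted sum of both terms.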
@@ -53,15 +53,16 @@ class Loss():
def __init__(self, config: dict) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._config = config
self._loss_dict = dict(mae=k_losses.mean_absolute_error,
mse=k_losses.mean_squared_error,
logcosh=k_losses.logcosh,
smooth_loss=losses.GeneralizedLoss(),
self._loss_dict = dict(gmsd=losses.GMSDLoss(),
l_inf_norm=losses.LInfNorm(),
ssim=losses.DSSIMObjective(),
laploss=losses.LaplacianPyramidLoss(),
logcosh=k_losses.logcosh,
ms_ssim=losses.MSSIMLoss(),
gmsd=losses.GMSDLoss(),
pixel_gradient_diff=losses.GradientLoss())
mae=k_losses.mean_absolute_error,
mse=k_losses.mean_squared_error,
pixel_gradient_diff=losses.GradientLoss(),
ssim=losses.DSSIMObjective(),
smooth_loss=losses.GeneralizedLoss(),)
self._mask_channels = self._get_mask_channels()
self._inputs: List[keras.layers.Layer] = []
self._names: List[str] = []
setup.cfg: 10 changed lines
@@ -10,5 +10,15 @@ exclude = .git, __pycache__
ignore_missing_imports = True
[mypy-keras.*]
ignore_missing_imports = True
[mypy-psutil.*]
ignore_missing_imports = True
[mypy-plaidml.*]
ignore_missing_imports = True
[mypy-tqdm.*]
ignore_missing_imports = True
[mypy-pynvml.*]
ignore_missing_imports = True
[mypy-pynvx.*]
ignore_missing_imports = True
[mypy-cv2.*]
ignore_missing_imports = True
@@ -41,11 +41,18 @@ def test_loss_output(loss_func, output_shape):
assert output.dtype == "float32" and not np.any(np.isnan(output))
_LWPARAMS = [losses.GeneralizedLoss(), losses.GradientLoss(), losses.GMSDLoss(),
losses.LInfNorm(), k_losses.mean_absolute_error, k_losses.mean_squared_error,
k_losses.logcosh, losses.DSSIMObjective(), losses.MSSIMLoss()]
_LWIDS = ["GeneralizedLoss", "GradientLoss", "GMSDLoss", "LInfNorm", "mae", "mse", "logcosh",
"DSSIMObjective", "MS-SSIM"]
_LWPARAMS = [losses.DSSIMObjective(),
losses.GeneralizedLoss(),
losses.GMSDLoss(),
losses.GradientLoss(),
losses.LaplacianPyramidLoss(),
losses.LInfNorm(),
k_losses.logcosh,
k_losses.mean_absolute_error,
k_losses.mean_squared_error,
losses.MSSIMLoss()]
_LWIDS = ["DSSIMObjective", "GeneralizedLoss", "GMSDLoss", "GradientLoss", "LaplacianPyramidLoss",
"LInfNorm", "logcosh", "mae", "mse", "MS-SSIM"]
_LWIDS = [f"{loss}[{get_backend().upper()}]" for loss in _LWIDS]
@@ -57,8 +64,8 @@ def test_loss_wrapper(loss_func):
pytest.skip("GMSD Loss is not currently compatible with PlaidML")
if hasattr(loss_func, "__name__") and loss_func.__name__ == "logcosh":
pytest.skip("LogCosh Loss is not currently compatible with PlaidML")
y_a = K.variable(np.random.random((2, 16, 16, 4)))
y_b = K.variable(np.random.random((2, 16, 16, 3)))
y_a = K.variable(np.random.random((2, 64, 64, 4)))
y_b = K.variable(np.random.random((2, 64, 64, 3)))
p_loss = losses.LossWrapper()
p_loss.add_loss(loss_func, 1.0, -1)
p_loss.add_loss(k_losses.mean_squared_error, 2.0, 3)
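The test inputs grow from 16x16 to 64x64, presumably because each of the five pyramid levels halves the spatial size and a 16x16 image would collapse before the final level (an inference, not stated in the commit). A quick trace of the halving:

for start in (16, 64):
    size = start
    trace = [size]
    for _ in range(5):      # max_levels=5 average-pool halvings
        size //= 2
        trace.append(size)
    print(start, trace)     # 16 -> [16, 8, 4, 2, 1, 0]; 64 -> [64, 32, 16, 8, 4, 2]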
@@ -14,6 +14,7 @@ from urllib.request import urlretrieve
import os
from os.path import join as pathjoin, expanduser
_TRAIN_ARGS = (1, 1) if os.environ.get("FACESWAP_BACKEND", "cpu").lower() == "amd" else (4, 4)
FAIL_COUNT = 0
TEST_COUNT = 0
_COLORS = {

@@ -167,18 +168,21 @@ def main():
run_test(
"Train lightweight model for 1 iteration with WTL.",
train_args(
"lightweight", pathjoin(vid_base, "model"),
pathjoin(vid_base, "faces"), extra_args="-wl"
)
)
train_args("lightweight",
pathjoin(vid_base, "model"),
pathjoin(vid_base, "faces"),
iterations=_TRAIN_ARGS[0],
batchsize=_TRAIN_ARGS[1],
extra_args="-wl"))
was_trained = run_test(
"Train lightweight model for 1 iterations WITHOUT WTL.",
train_args(
"lightweight", pathjoin(vid_base, "model"), pathjoin(vid_base, "faces")
)
)
train_args("lightweight",
pathjoin(vid_base, "model"),
pathjoin(vid_base, "faces"),
iterations=_TRAIN_ARGS[0],
batchsize=_TRAIN_ARGS[1],
extra_args="-wl"))
if was_trained:
run_test(