mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
Add AlexNet + SqueezeNet definitions
This commit is contained in:
parent
1d434b73a4
commit
ef79a3d8cb
|
|
@ -55,7 +55,7 @@ model.losses module
|
|||
-------------------
|
||||
|
||||
The losses listed here are generated from the docstrings in :mod:`lib.model.losses_tf`, however
|
||||
the functions are excactly the same for :mod:`lib.model.losses_plaid`. The correct loss module will
|
||||
the functions are exactly the same for :mod:`lib.model.losses_plaid`. The correct loss module will
|
||||
be imported as :mod:`lib.model.losses` depending on the backend in use.
|
||||
|
||||
.. rubric:: Module Summary
|
||||
|
|
@ -63,14 +63,32 @@ be imported as :mod:`lib.model.losses` depending on the backend in use.
|
|||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
~lib.model.losses_tf.DSSIMObjective
|
||||
~lib.model.losses_tf.GeneralizedLoss
|
||||
~lib.model.losses_tf.GMSDLoss
|
||||
~lib.model.losses_tf.GradientLoss
|
||||
~lib.model.losses_tf.LInfNorm
|
||||
~lib.model.losses_tf.LossWrapper
|
||||
~lib.model.loss.loss_tf.DSSIMObjective
|
||||
~lib.model.loss.loss_tf.FocalFrequencyLoss
|
||||
~lib.model.loss.loss_tf.GeneralizedLoss
|
||||
~lib.model.loss.loss_tf.GMSDLoss
|
||||
~lib.model.loss.loss_tf.GradientLoss
|
||||
~lib.model.loss.loss_tf.LaplacianPyramidLoss
|
||||
~lib.model.loss.loss_tf.LInfNorm
|
||||
~lib.model.loss.loss_tf.LossWrapper
|
||||
|
||||
.. automodule:: lib.model.losses_tf
|
||||
.. automodule:: lib.model.loss.loss_tf
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
model.nets module
|
||||
-----------------
|
||||
|
||||
.. rubric:: Module Summary
|
||||
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
~lib.model.nets.AlexNet
|
||||
~lib.model.nets.SqueezeNet
|
||||
|
||||
.. automodule:: lib.model.nets
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,17 @@ The Train Package handles the Model and Trainer plugins for training models in F
|
|||
model package
|
||||
=============
|
||||
|
||||
This package contains various helper functions that plugins can inherit from
|
||||
|
||||
.. rubric:: Module Summary
|
||||
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
~plugins.train.model._base.model
|
||||
~plugins.train.model._base.settings
|
||||
~plugins.train.model._base.io
|
||||
|
||||
model._base.model module
|
||||
------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
|||
|
||||
|
||||
class DSSIMObjective(): # pylint:disable=too-few-public-methods
|
||||
""" DSSIM and MS-DSSIM Loss Functions
|
||||
""" DSSIM Loss Function
|
||||
|
||||
Difference of Structural Similarity (DSSIM loss function).
|
||||
|
||||
|
|
@ -678,13 +678,12 @@ class LInfNorm(): # pylint:disable=too-few-public-methods
|
|||
return loss
|
||||
|
||||
|
||||
class LogCosh():
|
||||
class LogCosh(): # pylint:disable=too-few-public-methods
|
||||
"""Logarithm of the hyperbolic cosine of the prediction error.
|
||||
|
||||
`log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
|
||||
to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
|
||||
like the mean squared error, but will not be so strongly affected by the
|
||||
occasional wildly incorrect prediction.
|
||||
`log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and to `abs(x) - log(2)`
|
||||
for large `x`. This means that 'logcosh' works mostly like the mean squared error, but will not
|
||||
be so strongly affected by the occasional wildly incorrect prediction.
|
||||
"""
|
||||
def __call__(self,
|
||||
y_true: plaidml.tile.Value,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class DSSIMObjective(): # pylint:disable=too-few-public-methods
|
||||
""" DSSIM and MS-DSSIM Loss Functions
|
||||
""" DSSIM Loss Functions
|
||||
|
||||
Difference of Structural Similarity (DSSIM loss function).
|
||||
|
||||
|
|
|
|||
212
lib/model/nets.py
Normal file
212
lib/model/nets.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Ports of existing NN Architecture for use in faceswap.py """
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from lib.utils import get_backend
|
||||
|
||||
if get_backend() == "amd":
|
||||
from keras.layers import Concatenate, Conv2D, Input, MaxPool2D, ZeroPadding2D
|
||||
from keras.models import Model
|
||||
from plaidml.tile import Value as Tensor
|
||||
else:
|
||||
# Ignore linting errors from Tensorflow's thoroughly broken import system
|
||||
from tensorflow.keras.layers import Concatenate, Conv2D, Input, MaxPool2D, ZeroPadding2D # noqa pylint:disable=no-name-in-module,import-error
|
||||
from tensorflow.keras.models import Model # noqa pylint:disable=no-name-in-module,import-error
|
||||
from tensorflow import Tensor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _net(): # pylint:disable=too-few-public-methods
|
||||
""" Base class for existing NeuralNet architecture
|
||||
|
||||
Notes
|
||||
-----
|
||||
All architectures assume channels_last format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_shape, Tuple, optional
|
||||
The input shape for the model. Default: ``None``
|
||||
"""
|
||||
def __init__(self,
|
||||
input_shape: Optional[Tuple[int, int, int]] = None) -> None:
|
||||
logger.debug("Initializing: %s (input_shape: %s)", self.__class__.__name__, input_shape)
|
||||
self._input_shape = (None, None, 3) if input_shape is None else input_shape
|
||||
assert len(self._input_shape) == 3 and self._input_shape[-1] == 3, (
|
||||
"Input shape must be in the format (height, width, channels) and the number of "
|
||||
f"channels must equal 3. Received: {self._input_shape}")
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
|
||||
class AlexNet(_net): # pylint:disable=too-few-public-methods
|
||||
""" AlexNet ported from torchvision version.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This port only contains the features portion of the model.
|
||||
|
||||
Reference
|
||||
---------
|
||||
https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_shape, Tuple, optional
|
||||
The input shape for the model. Default: ``None``
|
||||
"""
|
||||
def __init__(self, input_shape: Optional[Tuple[int, int, int]] = None) -> None:
|
||||
super().__init__(input_shape)
|
||||
self._feature_indices = [0, 3, 6, 8, 10] # For naming equivalent to PyTorch
|
||||
self._filters = [64, 192, 384, 256, 256] # Filters at each block
|
||||
|
||||
@classmethod
|
||||
def _conv_block(cls,
|
||||
inputs: Tensor,
|
||||
padding: int,
|
||||
filters: int,
|
||||
kernel_size: int,
|
||||
strides: int,
|
||||
block_idx: int,
|
||||
max_pool: bool) -> Tensor:
|
||||
"""
|
||||
The Convolutional block for AlexNet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs: :class:`plaidml.tile.Value` or :class:`tf.Tensor`
|
||||
The input tensor to the block
|
||||
padding: int
|
||||
The amount of zero paddin to apply prior to convolution
|
||||
filters: int
|
||||
The number of filters to apply during convolution
|
||||
kernel_size: int
|
||||
The kernel size of the convolution
|
||||
strides: int
|
||||
The number of strides for the convolution
|
||||
block_idx: int
|
||||
The index of the current block (for standardized naming convention)
|
||||
max_pool: bool
|
||||
``True`` to apply a max pooling layer at the beginning of the block otherwise ``False``
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`plaidml.tile.Value` or :class:`tf.Tensor`
|
||||
The output of the Convolutional block
|
||||
"""
|
||||
name = f"features.{block_idx}"
|
||||
var_x = inputs
|
||||
if max_pool:
|
||||
var_x = MaxPool2D(pool_size=3, strides=2, name=f"{name}.pool")(var_x)
|
||||
var_x = ZeroPadding2D(padding=padding, name=f"{name}.pad")(var_x)
|
||||
var_x = Conv2D(filters,
|
||||
kernel_size=kernel_size,
|
||||
strides=strides,
|
||||
padding="valid",
|
||||
activation="relu",
|
||||
name=name)(var_x)
|
||||
return var_x
|
||||
|
||||
def __call__(self) -> Model:
|
||||
""" Create the AlexNet Model
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`keras.models.Model`
|
||||
The compiled AlexNet model
|
||||
"""
|
||||
inputs = Input(self._input_shape)
|
||||
var_x = inputs
|
||||
kernel_size = 11
|
||||
strides = 4
|
||||
|
||||
for idx, (filters, block_idx) in enumerate(zip(self._filters, self._feature_indices)):
|
||||
padding = 2 if idx < 2 else 1
|
||||
do_max_pool = 0 < idx < 3
|
||||
var_x = self._conv_block(var_x,
|
||||
padding,
|
||||
filters,
|
||||
kernel_size,
|
||||
strides,
|
||||
block_idx,
|
||||
do_max_pool)
|
||||
kernel_size = max(3, kernel_size // 2)
|
||||
strides = 1
|
||||
return Model(inputs=inputs, outputs=[var_x])
|
||||
|
||||
|
||||
class SqueezeNet(_net): # pylint:disable=too-few-public-methods
|
||||
""" SqueezeNet ported from torchvision version.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This port only contains the features portion of the model.
|
||||
|
||||
Reference
|
||||
---------
|
||||
https://arxiv.org/abs/1602.07360
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_shape, Tuple, optional
|
||||
The input shape for the model. Default: ``None``
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _fire(cls,
|
||||
inputs: Tensor,
|
||||
squeeze_planes: int,
|
||||
expand_planes: int,
|
||||
block_idx: int) -> Tensor:
|
||||
""" The fire block for SqueezeNet.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs: :class:`plaidml.tile.Value` or :class:`tf.Tensor`
|
||||
The input to the fire block
|
||||
squeeze_planes: int
|
||||
The number of filters for the squeeze convolution
|
||||
expand_planes: int
|
||||
The number of filters for the expand convolutions
|
||||
block_idx: int
|
||||
The index of the current block (for standardized naming convention)
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`plaidml.tile.Value` or :class:`tf.Tensor`
|
||||
The output of the SqueezeNet fire block
|
||||
"""
|
||||
name = f"features.{block_idx}"
|
||||
squeezed = Conv2D(squeeze_planes, 1, activation="relu", name=f"{name}.squeeze")(inputs)
|
||||
expand1 = Conv2D(expand_planes, 1, activation="relu", name=f"{name}.expand1x1")(squeezed)
|
||||
expand3 = Conv2D(expand_planes, 3,
|
||||
activation="relu", padding="same", name=f"{name}.expand3x3")(squeezed)
|
||||
return Concatenate(axis=-1, name=name)([expand1, expand3])
|
||||
|
||||
def __call__(self) -> Model:
|
||||
""" Create the SqueezeNet Model
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`keras.models.Model`
|
||||
The compiled SqueezeNet model
|
||||
"""
|
||||
inputs = Input(self._input_shape)
|
||||
var_x = Conv2D(64, 3, strides=2, activation="relu", name="features.0")(inputs)
|
||||
|
||||
block_idx = 2
|
||||
squeeze = 16
|
||||
expand = 64
|
||||
for idx in range(4):
|
||||
if idx < 3:
|
||||
var_x = MaxPool2D(pool_size=3, strides=2)(var_x)
|
||||
block_idx += 1
|
||||
var_x = self._fire(var_x, squeeze, expand, block_idx)
|
||||
block_idx += 1
|
||||
var_x = self._fire(var_x, squeeze, expand, block_idx)
|
||||
block_idx += 1
|
||||
squeeze += 16
|
||||
expand += 64
|
||||
return Model(inputs=inputs, outputs=[var_x])
|
||||
Loading…
Reference in New Issue
Block a user