Add AlexNet + SqueezeNet definitions

This commit is contained in:
torzdf 2022-06-18 18:21:41 +01:00
parent 1d434b73a4
commit ef79a3d8cb
5 changed files with 255 additions and 15 deletions

View File

@ -55,7 +55,7 @@ model.losses module
------------------- -------------------
The losses listed here are generated from the docstrings in :mod:`lib.model.losses_tf`, however The losses listed here are generated from the docstrings in :mod:`lib.model.losses_tf`, however
the functions are excactly the same for :mod:`lib.model.losses_plaid`. The correct loss module will the functions are exactly the same for :mod:`lib.model.losses_plaid`. The correct loss module will
be imported as :mod:`lib.model.losses` depending on the backend in use. be imported as :mod:`lib.model.losses` depending on the backend in use.
.. rubric:: Module Summary .. rubric:: Module Summary
@ -63,14 +63,32 @@ be imported as :mod:`lib.model.losses` depending on the backend in use.
.. autosummary:: .. autosummary::
:nosignatures: :nosignatures:
~lib.model.losses_tf.DSSIMObjective ~lib.model.loss.loss_tf.DSSIMObjective
~lib.model.losses_tf.GeneralizedLoss ~lib.model.loss.loss_tf.FocalFrequencyLoss
~lib.model.losses_tf.GMSDLoss ~lib.model.loss.loss_tf.GeneralizedLoss
~lib.model.losses_tf.GradientLoss ~lib.model.loss.loss_tf.GMSDLoss
~lib.model.losses_tf.LInfNorm ~lib.model.loss.loss_tf.GradientLoss
~lib.model.losses_tf.LossWrapper ~lib.model.loss.loss_tf.LaplacianPyramidLoss
~lib.model.loss.loss_tf.LInfNorm
~lib.model.loss.loss_tf.LossWrapper
.. automodule:: lib.model.losses_tf .. automodule:: lib.model.loss.loss_tf
:members:
:undoc-members:
:show-inheritance:
model.nets module
-----------------
.. rubric:: Module Summary
.. autosummary::
:nosignatures:
~lib.model.nets.AlexNet
~lib.model.nets.SqueezeNet
.. automodule:: lib.model.nets
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:

View File

@ -11,6 +11,17 @@ The Train Package handles the Model and Trainer plugins for training models in F
model package model package
============= =============
This package contains various helper functions that plugins can inherit from
.. rubric:: Module Summary
.. autosummary::
:nosignatures:
~plugins.train.model._base.model
~plugins.train.model._base.settings
~plugins.train.model._base.io
model._base.model module model._base.model module
------------------------ ------------------------

View File

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) # pylint:disable=invalid-name
class DSSIMObjective(): # pylint:disable=too-few-public-methods class DSSIMObjective(): # pylint:disable=too-few-public-methods
""" DSSIM and MS-DSSIM Loss Functions """ DSSIM Loss Function
Difference of Structural Similarity (DSSIM loss function). Difference of Structural Similarity (DSSIM loss function).
@ -678,13 +678,12 @@ class LInfNorm(): # pylint:disable=too-few-public-methods
return loss return loss
class LogCosh(): class LogCosh(): # pylint:disable=too-few-public-methods
"""Logarithm of the hyperbolic cosine of the prediction error. """Logarithm of the hyperbolic cosine of the prediction error.
`log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and to `abs(x) - log(2)`
to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly for large `x`. This means that 'logcosh' works mostly like the mean squared error, but will not
like the mean squared error, but will not be so strongly affected by the be so strongly affected by the occasional wildly incorrect prediction.
occasional wildly incorrect prediction.
""" """
def __call__(self, def __call__(self,
y_true: plaidml.tile.Value, y_true: plaidml.tile.Value,

View File

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class DSSIMObjective(): # pylint:disable=too-few-public-methods class DSSIMObjective(): # pylint:disable=too-few-public-methods
""" DSSIM and MS-DSSIM Loss Functions """ DSSIM Loss Functions
Difference of Structural Similarity (DSSIM loss function). Difference of Structural Similarity (DSSIM loss function).

212
lib/model/nets.py Normal file
View File

@ -0,0 +1,212 @@
#!/usr/bin/env python3
""" Ports of existing NN Architecture for use in faceswap.py """
import logging
from typing import Optional, Tuple
from lib.utils import get_backend
if get_backend() == "amd":
from keras.layers import Concatenate, Conv2D, Input, MaxPool2D, ZeroPadding2D
from keras.models import Model
from plaidml.tile import Value as Tensor
else:
# Ignore linting errors from Tensorflow's thoroughly broken import system
from tensorflow.keras.layers import Concatenate, Conv2D, Input, MaxPool2D, ZeroPadding2D # noqa pylint:disable=no-name-in-module,import-error
from tensorflow.keras.models import Model # noqa pylint:disable=no-name-in-module,import-error
from tensorflow import Tensor
logger = logging.getLogger(__name__)
class _net():  # pylint:disable=too-few-public-methods
    """ Common parent for the ported Neural Network architectures.

    Stores the input shape that the child classes build their model from and checks that it is
    a 3 dimensional, 3 channel shape.

    Notes
    -----
    All architectures assume channels_last format

    Parameters
    ----------
    input_shape: tuple, optional
        The (height, width, channels) input shape for the model. When ``None`` is provided the
        shape ``(None, None, 3)`` is used. Default: ``None``
    """
    def __init__(self, input_shape: Optional[Tuple[int, int, int]] = None) -> None:
        cls_name = self.__class__.__name__
        logger.debug("Initializing: %s (input_shape: %s)", cls_name, input_shape)

        # Fall back to an unconstrained height and width with 3 channels
        if input_shape is None:
            input_shape = (None, None, 3)
        self._input_shape = input_shape

        is_valid = len(self._input_shape) == 3 and self._input_shape[-1] == 3
        assert is_valid, (
            "Input shape must be in the format (height, width, channels) and the number of "
            f"channels must equal 3. Received: {self._input_shape}")

        logger.debug("Initialized: %s", cls_name)
class AlexNet(_net):  # pylint:disable=too-few-public-methods
    """ AlexNet ported from torchvision version.

    Notes
    -----
    This port only contains the features portion of the model.

    Reference
    ---------
    https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf

    Parameters
    ----------
    input_shape: tuple, optional
        The input shape for the model. Default: ``None``
    """
    def __init__(self, input_shape: Optional[Tuple[int, int, int]] = None) -> None:
        super().__init__(input_shape)
        # Index given to each convolution so that layer naming is equivalent to PyTorch
        self._feature_indices = [0, 3, 6, 8, 10]
        # Number of filters used by the convolution in each block
        self._filters = [64, 192, 384, 256, 256]

    @classmethod
    def _conv_block(cls,
                    inputs: Tensor,
                    padding: int,
                    filters: int,
                    kernel_size: int,
                    strides: int,
                    block_idx: int,
                    max_pool: bool) -> Tensor:
        """ A single AlexNet block: optional max pooling, then zero padding, then a ReLU
        activated convolution.

        Parameters
        ----------
        inputs: :class:`plaidml.tile.Value` or :class:`tf.Tensor`
            The input tensor to the block
        padding: int
            The amount of zero padding to apply prior to convolution
        filters: int
            The number of filters to apply during convolution
        kernel_size: int
            The kernel size of the convolution
        strides: int
            The number of strides for the convolution
        block_idx: int
            The index of the current block (for standardized naming convention)
        max_pool: bool
            ``True`` to apply a max pooling layer at the beginning of the block otherwise
            ``False``

        Returns
        -------
        :class:`plaidml.tile.Value` or :class:`tf.Tensor`
            The output of the Convolutional block
        """
        layer_name = f"features.{block_idx}"
        tensor = inputs

        if max_pool:
            pool = MaxPool2D(pool_size=3, strides=2, name=f"{layer_name}.pool")
            tensor = pool(tensor)

        # Padding is applied explicitly, so the convolution itself is left unpadded ("valid")
        pad = ZeroPadding2D(padding=padding, name=f"{layer_name}.pad")
        tensor = pad(tensor)

        conv = Conv2D(filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding="valid",
                      activation="relu",
                      name=layer_name)
        return conv(tensor)

    def __call__(self) -> Model:
        """ Create the AlexNet Model

        Returns
        -------
        :class:`keras.models.Model`
            The AlexNet model (features portion only)
        """
        inputs = Input(self._input_shape)
        net = inputs

        # The first block uses an 11px kernel with a stride of 4. After each block the kernel
        # size is halved (never dropping below 3) and the stride is fixed at 1.
        kernel, stride = 11, 4
        for position, (n_filters, feature_idx) in enumerate(zip(self._filters,
                                                                self._feature_indices)):
            net = self._conv_block(net,
                                   2 if position < 2 else 1,  # padding
                                   n_filters,
                                   kernel,
                                   stride,
                                   feature_idx,
                                   position in (1, 2))  # max pool before 2nd + 3rd blocks
            kernel, stride = max(3, kernel // 2), 1

        return Model(inputs=inputs, outputs=[net])
class SqueezeNet(_net):  # pylint:disable=too-few-public-methods
    """ SqueezeNet ported from torchvision version.

    Notes
    -----
    This port only contains the features portion of the model.

    Reference
    ---------
    https://arxiv.org/abs/1602.07360

    Parameters
    ----------
    input_shape: tuple, optional
        The input shape for the model. Default: ``None``
    """
    @classmethod
    def _fire(cls,
              inputs: Tensor,
              squeeze_planes: int,
              expand_planes: int,
              block_idx: int) -> Tensor:
        """ The fire block for SqueezeNet.

        A 1x1 "squeeze" convolution feeds two parallel "expand" convolutions (1x1 and 3x3)
        whose outputs are concatenated along the last axis.

        Parameters
        ----------
        inputs: :class:`plaidml.tile.Value` or :class:`tf.Tensor`
            The input to the fire block
        squeeze_planes: int
            The number of filters for the squeeze convolution
        expand_planes: int
            The number of filters for the expand convolutions
        block_idx: int
            The index of the current block (for standardized naming convention)

        Returns
        -------
        :class:`plaidml.tile.Value` or :class:`tf.Tensor`
            The output of the SqueezeNet fire block
        """
        prefix = f"features.{block_idx}"

        squeeze_layer = Conv2D(squeeze_planes, 1, activation="relu", name=f"{prefix}.squeeze")
        squeezed = squeeze_layer(inputs)

        branch_1x1 = Conv2D(expand_planes,
                            1,
                            activation="relu",
                            name=f"{prefix}.expand1x1")(squeezed)
        branch_3x3 = Conv2D(expand_planes,
                            3,
                            activation="relu",
                            padding="same",
                            name=f"{prefix}.expand3x3")(squeezed)

        return Concatenate(axis=-1, name=prefix)([branch_1x1, branch_3x3])

    def __call__(self) -> Model:
        """ Create the SqueezeNet Model

        Returns
        -------
        :class:`keras.models.Model`
            The SqueezeNet model (features portion only)
        """
        inputs = Input(self._input_shape)
        net = Conv2D(64, 3, strides=2, activation="relu", name="features.0")(inputs)

        block_idx = 2
        for stage in range(4):
            # Filter counts grow linearly per stage: squeeze 16, 32, 48, 64 and
            # expand 64, 128, 192, 256
            multiplier = stage + 1

            # Every stage except the last is preceded by max pooling. The pooling layer
            # consumes a block index so that fire block naming lines up with PyTorch.
            if stage != 3:
                net = MaxPool2D(pool_size=3, strides=2)(net)
                block_idx += 1

            # Each stage consists of 2 consecutive fire blocks with identical filter counts
            for _ in range(2):
                net = self._fire(net, 16 * multiplier, 64 * multiplier, block_idx)
                block_idx += 1

        return Model(inputs=inputs, outputs=[net])