mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 12:20:27 +01:00
Remove all saved models from repo
All models now download when required Model downloader can handle multiple files in model
This commit is contained in:
parent
7847c4a5ac
commit
ffd72b32d6
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -1,17 +1,11 @@
|
|||
*
|
||||
!setup.cfg
|
||||
!*.caffemodel
|
||||
!*.dat
|
||||
!*.h5
|
||||
!*.ico
|
||||
!*.inf
|
||||
!*.keep
|
||||
!*.md
|
||||
!*.npy
|
||||
!*.nsi
|
||||
!*.pb
|
||||
!*.png
|
||||
!*.prototxt
|
||||
!*.py
|
||||
!*.txt
|
||||
!.cache
|
||||
|
|
|
|||
46
lib/utils.py
46
lib/utils.py
|
|
@ -225,10 +225,29 @@ class GetModel():
|
|||
""" Check for models in their cache path
|
||||
If available, return the path, if not available, get, unzip and install model
|
||||
|
||||
model_name: The name of the model to be loaded
|
||||
cache_dir: the model cache folder folder of the current plugin calling this class """
|
||||
model_filename: The name of the model to be loaded (see notes below)
|
||||
cache_dir: The model cache folder of the current plugin calling this class
|
||||
IE: The folder that holds the model to be loaded.
|
||||
|
||||
NB: Models must have a certain naming convention:
|
||||
IE: <model_name>_v<version_number>.<extension>
|
||||
EG: s3fd_v1.pb
|
||||
|
||||
Multiple models can exist within the model_filename. They should be passed as a list
|
||||
and follow the same naming convention as above. Any differences in filename should
|
||||
occur AFTER the version number.
|
||||
IE: [<model_name>_v<version_number><differentiating_information>.<extension>]
|
||||
EG: [mtcnn_det_v1.1.py, mtcnn_det_v1.2.py, mtcnn_det_v1.3.py]
|
||||
[resnet_ssd_v1.caffemodel, resnet_ssd_v1.prototext]
|
||||
|
||||
Models to be handled by this class must be added to the _model_id property
|
||||
with their appropriate github identier mapped.
|
||||
See https://github.com/deepfakes-models/faceswap-models for more information
|
||||
"""
|
||||
|
||||
def __init__(self, model_filename, cache_dir):
|
||||
if not isinstance(model_filename, list):
|
||||
model_filename = [model_filename]
|
||||
self.model_filename = model_filename
|
||||
self.cache_dir = cache_dir
|
||||
self.url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
|
||||
|
|
@ -244,6 +263,9 @@ class GetModel():
|
|||
# EXTRACT (SECTION 1)
|
||||
"face-alignment-network_2d4": 0,
|
||||
"cnn-facial-landmark": 1,
|
||||
"mtcnn_det": 2,
|
||||
"s3fd": 3,
|
||||
"resnet_ssd": 4,
|
||||
# TRAIN (SECTION 2)
|
||||
# CONVERT (SECTION 3)
|
||||
}
|
||||
|
|
@ -251,29 +273,31 @@ class GetModel():
|
|||
|
||||
@property
|
||||
def _model_full_name(self):
|
||||
""" Return the model version from the filename """
|
||||
retval = os.path.splitext(self.model_filename)[0]
|
||||
""" Return the model full name from the filename(s) """
|
||||
common_prefix = os.path.commonprefix(self.model_filename)
|
||||
retval = os.path.splitext(common_prefix)[0]
|
||||
logger.trace(retval)
|
||||
return retval
|
||||
|
||||
@property
|
||||
def _model_name(self):
|
||||
""" Return the model version from the filename """
|
||||
""" Return the model name from the model full name """
|
||||
retval = self._model_full_name[:self._model_full_name.rfind("_")]
|
||||
logger.trace(retval)
|
||||
return retval
|
||||
|
||||
@property
|
||||
def _model_version(self):
|
||||
""" Return the model version from the filename """
|
||||
""" Return the model version from the model full name """
|
||||
retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
|
||||
logger.trace(retval)
|
||||
return retval
|
||||
|
||||
@property
|
||||
def _model_path(self):
|
||||
""" Return the model path in the cache folder """
|
||||
retval = os.path.join(self.cache_dir, self.model_filename)
|
||||
""" Return the model path(s) in the cache folder """
|
||||
retval = [os.path.join(self.cache_dir, fname) for fname in self.model_filename]
|
||||
retval = retval[0] if len(retval) == 1 else retval
|
||||
logger.trace(retval)
|
||||
return retval
|
||||
|
||||
|
|
@ -286,6 +310,10 @@ class GetModel():
|
|||
|
||||
@property
|
||||
def _model_exists(self):
|
||||
""" Check model(s) exist """
|
||||
if isinstance(self._model_path, list):
|
||||
retval = all(os.path.exists(pth) for pth in self._model_path)
|
||||
else:
|
||||
retval = os.path.exists(self._model_path)
|
||||
logger.trace(retval)
|
||||
return retval
|
||||
|
|
@ -343,6 +371,8 @@ class GetModel():
|
|||
else:
|
||||
logger.error("Failed to download model. Exiting. (Error: '%s', URL: '%s')",
|
||||
str(err), self._url_download)
|
||||
logger.info("You can manually download the model from: %s and unzip the "
|
||||
"contents to: %s", self._url_download, self.cache_dir)
|
||||
exit(1)
|
||||
|
||||
def write_zipfile(self, response):
|
||||
|
|
|
|||
|
|
@ -84,6 +84,9 @@ class Aligner():
|
|||
@staticmethod
|
||||
def get_model(model_filename):
|
||||
""" Check if model is available, if not, download and unzip it """
|
||||
if model_filename is None:
|
||||
logger.debug("No model_filename specified. Returning None")
|
||||
return None
|
||||
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
|
||||
model = GetModel(model_filename, cache_path)
|
||||
return model.model_path
|
||||
|
|
|
|||
0
plugins/extract/detect/.cache/.keep
Normal file
0
plugins/extract/detect/.cache/.keep
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -20,7 +20,7 @@ import cv2
|
|||
import dlib
|
||||
|
||||
from lib.gpu_stats import GPUStats
|
||||
from lib.utils import rotate_landmarks
|
||||
from lib.utils import rotate_landmarks, GetModel
|
||||
from plugins.extract._config import Config
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
|
@ -33,12 +33,11 @@ def get_config(plugin_name):
|
|||
|
||||
class Detector():
|
||||
""" Detector object """
|
||||
def __init__(self, loglevel, rotation=None, min_size=0):
|
||||
logger.debug("Initializing %s: (rotation: %s, min_size: %s)",
|
||||
self.__class__.__name__, rotation, min_size)
|
||||
def __init__(self, loglevel, model_filename=None, rotation=None, min_size=0):
|
||||
logger.debug("Initializing %s: (model_filename: %s, rotation: %s, min_size: %s)",
|
||||
self.__class__.__name__, model_filename, rotation, min_size)
|
||||
self.config = get_config(".".join(self.__module__.split(".")[-2:]))
|
||||
self.loglevel = loglevel
|
||||
self.cachepath = os.path.join(os.path.dirname(__file__), ".cache")
|
||||
self.rotation = self.get_rotation_angles(rotation)
|
||||
self.min_size = min_size
|
||||
self.parent_is_pool = False
|
||||
|
|
@ -50,7 +49,7 @@ class Detector():
|
|||
self.queues = {"in": None, "out": None}
|
||||
|
||||
# Path to model if required
|
||||
self.model_path = self.set_model_path()
|
||||
self.model_path = self.get_model(model_filename)
|
||||
|
||||
# Target image size for passing images through the detector
|
||||
# Set to tuple of dimensions (x, y) or int of pixel count
|
||||
|
|
@ -69,13 +68,6 @@ class Detector():
|
|||
logger.debug("Initialized _base %s", self.__class__.__name__)
|
||||
|
||||
# <<< OVERRIDE METHODS >>> #
|
||||
# These methods must be overriden when creating a plugin
|
||||
@staticmethod
|
||||
def set_model_path():
|
||||
""" path to data file/models
|
||||
override for specific detector """
|
||||
raise NotImplementedError()
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
""" Inititalize the detector
|
||||
Tasks to be run before any detection is performed.
|
||||
|
|
@ -99,6 +91,17 @@ class Detector():
|
|||
exit(1)
|
||||
logger.debug("Detecting Faces (args: %s, kwargs: %s)", args, kwargs)
|
||||
|
||||
# <<< GET MODEL >>> #
|
||||
@staticmethod
|
||||
def get_model(model_filename):
|
||||
""" Check if model is available, if not, download and unzip it """
|
||||
if model_filename is None:
|
||||
logger.debug("No model_filename specified. Returning None")
|
||||
return None
|
||||
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
|
||||
model = GetModel(model_filename, cache_path)
|
||||
return model.model_path
|
||||
|
||||
# <<< DETECTION WRAPPER >>> #
|
||||
def run(self, *args, **kwargs):
|
||||
""" Parent detect process.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
""" OpenCV DNN Face detection plugin """
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -11,23 +10,14 @@ from ._base import cv2, Detector, dlib, logger
|
|||
class Detect(Detector):
|
||||
""" CV2 DNN detector for face recognition """
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
model_filename = ["resnet_ssd_v1.caffemodel", "resnet_ssd_v1.prototxt"]
|
||||
super().__init__(model_filename=model_filename, **kwargs)
|
||||
self.parent_is_pool = True
|
||||
self.target = (300, 300) # Doesn't use VRAM
|
||||
self.vram = 0
|
||||
self.config_file = os.path.join(self.cachepath, "deploy.prototxt")
|
||||
self.detector = None
|
||||
self.confidence = self.config["confidence"] / 100
|
||||
|
||||
def set_model_path(self):
|
||||
""" CV2 DNN model file """
|
||||
model_path = os.path.join(self.cachepath, "res10_300x300_ssd_iter_140000_fp16.caffemodel")
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception("Error: Unable to find {}, reinstall "
|
||||
"the lib!".format(model_path))
|
||||
logger.debug("Loading model: '%s'", model_path)
|
||||
return model_path
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
""" Calculate batch size """
|
||||
super().initialize(*args, **kwargs)
|
||||
|
|
@ -39,8 +29,8 @@ class Detect(Detector):
|
|||
def detect_faces(self, *args, **kwargs):
|
||||
""" Detect faces in grayscale image """
|
||||
super().detect_faces(*args, **kwargs)
|
||||
detector = cv2.dnn.readNetFromCaffe(self.config_file, # pylint: disable=no-member
|
||||
self.model_path)
|
||||
detector = cv2.dnn.readNetFromCaffe(self.model_path[1], # pylint: disable=no-member
|
||||
self.model_path[0])
|
||||
detector.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU) # pylint: disable=no-member
|
||||
while True:
|
||||
item = self.get_item()
|
||||
|
|
|
|||
|
|
@ -9,10 +9,6 @@ class Detect(Detector):
|
|||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_model_path(self):
|
||||
""" No model required for Manual Detector """
|
||||
return None
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
""" Create the mtcnn detector """
|
||||
super().initialize(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ def import_tensorflow():
|
|||
class Detect(Detector):
|
||||
""" MTCNN detector for face recognition """
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
model_filename = ["mtcnn_det_v1.1.npy", "mtcnn_det_v1.2.npy", "mtcnn_det_v1.3.npy"]
|
||||
super().__init__(model_filename=model_filename, **kwargs)
|
||||
self.kwargs = self.validate_kwargs()
|
||||
self.name = "mtcnn"
|
||||
self.target = 2073600 # Uses approx 1.30 GB of VRAM
|
||||
|
|
@ -60,16 +61,6 @@ class Detect(Detector):
|
|||
logger.debug("Using mtcnn kwargs: %s", kwargs)
|
||||
return kwargs
|
||||
|
||||
def set_model_path(self):
|
||||
""" Load the mtcnn models """
|
||||
for model in ("det1.npy", "det2.npy", "det3.npy"):
|
||||
model_path = os.path.join(self.cachepath, model)
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception("Error: Unable to find {}, reinstall "
|
||||
"the lib!".format(model_path))
|
||||
logger.debug("Loading model: '%s'", model_path)
|
||||
return self.cachepath
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
""" Create the mtcnn detector """
|
||||
try:
|
||||
|
|
@ -513,15 +504,15 @@ def create_mtcnn(sess, model_path):
|
|||
with tf.variable_scope('pnet'):
|
||||
data = tf.placeholder(tf.float32, (None, None, None, 3), 'input')
|
||||
pnet = PNet({'data': data})
|
||||
pnet.load(os.path.join(model_path, 'det1.npy'), sess)
|
||||
pnet.load(model_path[0], sess)
|
||||
with tf.variable_scope('rnet'):
|
||||
data = tf.placeholder(tf.float32, (None, 24, 24, 3), 'input')
|
||||
rnet = RNet({'data': data})
|
||||
rnet.load(os.path.join(model_path, 'det2.npy'), sess)
|
||||
rnet.load(model_path[1], sess)
|
||||
with tf.variable_scope('onet'):
|
||||
data = tf.placeholder(tf.float32, (None, 48, 48, 3), 'input')
|
||||
onet = ONet({'data': data})
|
||||
onet.load(os.path.join(model_path, 'det3.npy'), sess)
|
||||
onet.load(model_path[2], sess)
|
||||
|
||||
pnet_fun = lambda img: sess.run(('pnet/conv4-2/BiasAdd:0', # noqa
|
||||
'pnet/prob1:0'),
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ Adapted from S3FD Port in FAN:
|
|||
https://github.com/1adrianb/face-alignment
|
||||
"""
|
||||
|
||||
import os
|
||||
from scipy.special import logsumexp
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -18,22 +17,14 @@ from ._base import Detector, dlib, logger
|
|||
class Detect(Detector):
|
||||
""" S3FD detector for face recognition """
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
model_filename = "s3fd_v1.pb"
|
||||
super().__init__(model_filename=model_filename, **kwargs)
|
||||
self.name = "s3fd"
|
||||
self.target = (640, 640) # Uses approx 4 GB of VRAM
|
||||
self.vram = 4096
|
||||
self.min_vram = 1024 # Will run at this with warnings
|
||||
self.model = None
|
||||
|
||||
def set_model_path(self):
|
||||
""" Load the s3fd model """
|
||||
model_path = os.path.join(self.cachepath, "s3fd.pb")
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception("Error: Unable to find {}, reinstall "
|
||||
"the lib!".format(model_path))
|
||||
logger.debug("Loading model: '%s'", model_path)
|
||||
return model_path
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
""" Create the s3fd detector """
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user