mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
Sort by face - Inherit allow_growth option from extract.ini
This commit is contained in:
parent
6595cdf062
commit
30acca8541
|
|
@ -38,9 +38,9 @@ class VGGFace2():
|
|||
https://creativecommons.org/licenses/by-nc/4.0/
|
||||
"""
|
||||
|
||||
def __init__(self, backend="GPU", loglevel="INFO"):
|
||||
logger.debug("Initializing %s: (backend: %s, loglevel: %s)",
|
||||
self.__class__.__name__, backend, loglevel)
|
||||
def __init__(self, backend="GPU", allow_growth=False, loglevel="INFO"):
|
||||
logger.debug("Initializing %s: (backend: %s, allow_growth: %s, loglevel: %s)",
|
||||
self.__class__.__name__, backend, allow_growth, loglevel)
|
||||
backend = backend.upper()
|
||||
git_model_id = 10
|
||||
model_filename = ["vggface2_resnet50_v2.h5"]
|
||||
|
|
@ -48,12 +48,12 @@ class VGGFace2():
|
|||
# Average image provided in https://github.com/ox-vgg/vgg_face2
|
||||
self.average_img = np.array([91.4953, 103.8827, 131.0912])
|
||||
|
||||
self.model = self._get_model(git_model_id, model_filename, backend)
|
||||
self.model = self._get_model(git_model_id, model_filename, backend, allow_growth)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
# <<< GET MODEL >>> #
|
||||
@staticmethod
|
||||
def _get_model(git_model_id, model_filename, backend):
|
||||
def _get_model(git_model_id, model_filename, backend, allow_growth):
|
||||
""" Check if model is available, if not, download and unzip it
|
||||
|
||||
Parameters
|
||||
|
|
@ -66,6 +66,8 @@ class VGGFace2():
|
|||
information)
|
||||
backend: ['GPU', 'CPU']
|
||||
Whether to run inference on a GPU or on the CPU
|
||||
allow_growth: bool
|
||||
``True`` if Tensorflow's allow_growth option should be set, otherwise ``False``
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
|
@ -78,6 +80,19 @@ class VGGFace2():
|
|||
if os.environ.get("KERAS_BACKEND", "") == "plaidml.keras.backend":
|
||||
logger.info("Switching to tensorflow backend.")
|
||||
os.environ["KERAS_BACKEND"] = "tensorflow"
|
||||
|
||||
if allow_growth:
|
||||
# TODO This needs to be centralized. Just a hacky fix to read the allow growth config
|
||||
# option from the Extraction config file
|
||||
logger.info("Enabling Tensorflow 'allow_growth' option")
|
||||
import tensorflow as tf
|
||||
from keras.backend.tensorflow_backend import set_session
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
config.gpu_options.visible_device_list = "0"
|
||||
set_session(tf.Session(config=config))
|
||||
logger.debug("Set Tensorflow 'allow_growth' option")
|
||||
|
||||
import keras
|
||||
from lib.model.layers import L2_normalize
|
||||
if backend == "CPU":
|
||||
|
|
|
|||
|
|
@ -17,8 +17,10 @@ from tqdm import tqdm
|
|||
from lib.serializer import get_serializer_from_filename
|
||||
from lib.faces_detect import DetectedFace
|
||||
from lib.image import ImagesLoader, read_image
|
||||
from lib.utils import get_backend
|
||||
from lib.vgg_face2_keras import VGGFace2 as VGGFace
|
||||
from plugins.extract.pipeline import Extractor, ExtractMedia
|
||||
from plugins.extract._config import Config
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
|
@ -56,7 +58,13 @@ class Sort():
|
|||
|
||||
# Load VGG Face if sorting by face
|
||||
if self.args.sort_method.lower() == "face":
|
||||
self.vgg_face = VGGFace(backend=self.args.backend, loglevel=self.args.loglevel)
|
||||
conf = Config("global", configfile=self.args.configfile)
|
||||
allow_growth = (conf.config_dict["allow_growth"] and
|
||||
self.args.backend.lower() == "gpu" and
|
||||
get_backend() == "nvidia")
|
||||
self.vgg_face = VGGFace(backend=self.args.backend,
|
||||
allow_growth=allow_growth,
|
||||
loglevel=self.args.loglevel)
|
||||
|
||||
# If logging is enabled, prepare container
|
||||
if self.args.log_changes:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user