Sort by face - Inherit allow_growth option from extract.ini

This commit is contained in:
torzdf 2020-07-04 10:38:46 +01:00
parent 6595cdf062
commit 30acca8541
2 changed files with 29 additions and 6 deletions

View File

@ -38,9 +38,9 @@ class VGGFace2():
https://creativecommons.org/licenses/by-nc/4.0/
"""
def __init__(self, backend="GPU", loglevel="INFO"):
logger.debug("Initializing %s: (backend: %s, loglevel: %s)",
self.__class__.__name__, backend, loglevel)
def __init__(self, backend="GPU", allow_growth=False, loglevel="INFO"):
logger.debug("Initializing %s: (backend: %s, allow_growth: %s, loglevel: %s)",
self.__class__.__name__, backend, allow_growth, loglevel)
backend = backend.upper()
git_model_id = 10
model_filename = ["vggface2_resnet50_v2.h5"]
@ -48,12 +48,12 @@ class VGGFace2():
# Average image provided in https://github.com/ox-vgg/vgg_face2
self.average_img = np.array([91.4953, 103.8827, 131.0912])
self.model = self._get_model(git_model_id, model_filename, backend)
self.model = self._get_model(git_model_id, model_filename, backend, allow_growth)
logger.debug("Initialized %s", self.__class__.__name__)
# <<< GET MODEL >>> #
@staticmethod
def _get_model(git_model_id, model_filename, backend):
def _get_model(git_model_id, model_filename, backend, allow_growth):
""" Check if model is available, if not, download and unzip it
Parameters
@ -66,6 +66,8 @@ class VGGFace2():
information)
backend: ['GPU', 'CPU']
Whether to run inference on a GPU or on the CPU
allow_growth: bool
``True`` if Tensorflow's allow_growth option should be set, otherwise ``False``
See Also
--------
@ -78,6 +80,19 @@ class VGGFace2():
if os.environ.get("KERAS_BACKEND", "") == "plaidml.keras.backend":
logger.info("Switching to tensorflow backend.")
os.environ["KERAS_BACKEND"] = "tensorflow"
if allow_growth:
# TODO This needs to be centralized. Just a hacky fix to read the allow growth config
# option from the Extraction config file
logger.info("Enabling Tensorflow 'allow_growth' option")
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = "0"
set_session(tf.Session(config=config))
logger.debug("Set Tensorflow 'allow_growth' option")
import keras
from lib.model.layers import L2_normalize
if backend == "CPU":

View File

@ -17,8 +17,10 @@ from tqdm import tqdm
from lib.serializer import get_serializer_from_filename
from lib.faces_detect import DetectedFace
from lib.image import ImagesLoader, read_image
from lib.utils import get_backend
from lib.vgg_face2_keras import VGGFace2 as VGGFace
from plugins.extract.pipeline import Extractor, ExtractMedia
from plugins.extract._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -56,7 +58,13 @@ class Sort():
# Load VGG Face if sorting by face
if self.args.sort_method.lower() == "face":
self.vgg_face = VGGFace(backend=self.args.backend, loglevel=self.args.loglevel)
conf = Config("global", configfile=self.args.configfile)
allow_growth = (conf.config_dict["allow_growth"] and
self.args.backend.lower() == "gpu" and
get_backend() == "nvidia")
self.vgg_face = VGGFace(backend=self.args.backend,
allow_growth=allow_growth,
loglevel=self.args.loglevel)
# If logging is enabled, prepare container
if self.args.log_changes: