mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 12:20:27 +01:00
bugfix: mask tool memory leak in batch-mode
This commit is contained in:
parent
d0a8d59812
commit
a076afa910
|
|
@ -4,6 +4,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from multiprocessing import Process
|
||||||
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
|
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
|
@ -79,28 +80,28 @@ class Mask(): # pylint:disable=too-few-public-methods
|
||||||
logger.debug("Returning output: '%s' for input: '%s'", retval, input_location)
|
logger.debug("Returning output: '%s' for input: '%s'", retval, input_location)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
def _get_extractor(self) -> Optional[Extractor]:
|
@staticmethod
|
||||||
""" Obtain a Mask extractor plugin and launch it
|
def _run_mask_process(arguments: Namespace) -> None:
|
||||||
|
""" The mask process to be run in a spawned process.
|
||||||
|
|
||||||
Returns
|
In some instances, batch-mode memory leaks. Launching each job in a separate process
|
||||||
-------
|
prevents this leak.
|
||||||
:class:`plugins.extract.pipeline.Extractor`:
|
|
||||||
The launched Extractor
|
Parameters
|
||||||
|
----------
|
||||||
|
arguments: :class:`argparse.Namespace`
|
||||||
|
The :mod:`argparse` arguments to be used for the given job
|
||||||
"""
|
"""
|
||||||
if self._args.processing == "output":
|
logger.debug("Starting process: (arguments: %s)", arguments)
|
||||||
logger.debug("Update type `output` selected. Not launching extractor")
|
mask = _Mask(arguments)
|
||||||
return None
|
mask.process()
|
||||||
logger.debug("masker: %s", self._args.masker)
|
logger.debug("Finished process: (arguments: %s)", arguments)
|
||||||
extractor = Extractor(None, None, self._args.masker, exclude_gpus=self._args.exclude_gpus)
|
|
||||||
logger.debug(extractor)
|
|
||||||
return extractor
|
|
||||||
|
|
||||||
def process(self) -> None:
|
def process(self) -> None:
|
||||||
""" The entry point for triggering the Extraction Process.
|
""" The entry point for triggering the Extraction Process.
|
||||||
|
|
||||||
Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
|
Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
|
||||||
"""
|
"""
|
||||||
extractor = self._get_extractor()
|
|
||||||
for idx, location in enumerate(self._input_locations):
|
for idx, location in enumerate(self._input_locations):
|
||||||
if self._args.batch_mode:
|
if self._args.batch_mode:
|
||||||
logger.info("Processing job %s of %s: %s",
|
logger.info("Processing job %s of %s: %s",
|
||||||
|
|
@ -115,14 +116,12 @@ class Mask(): # pylint:disable=too-few-public-methods
|
||||||
else:
|
else:
|
||||||
arguments = self._args
|
arguments = self._args
|
||||||
|
|
||||||
if extractor is not None:
|
if len(self._input_locations) > 1:
|
||||||
extractor.launch()
|
proc = Process(target=self._run_mask_process, args=(arguments, ))
|
||||||
|
proc.start()
|
||||||
mask = _Mask(arguments, extractor)
|
proc.join()
|
||||||
mask.process()
|
else:
|
||||||
|
self._run_mask_process(arguments)
|
||||||
if extractor is not None:
|
|
||||||
extractor.reset_phase_index()
|
|
||||||
|
|
||||||
|
|
||||||
class _Mask(): # pylint:disable=too-few-public-methods
|
class _Mask(): # pylint:disable=too-few-public-methods
|
||||||
|
|
@ -136,12 +135,9 @@ class _Mask(): # pylint:disable=too-few-public-methods
|
||||||
----------
|
----------
|
||||||
arguments: :class:`argparse.Namespace`
|
arguments: :class:`argparse.Namespace`
|
||||||
The :mod:`argparse` arguments as passed in from :mod:`tools.py`
|
The :mod:`argparse` arguments as passed in from :mod:`tools.py`
|
||||||
extractor: :class:`plugins.extract.pipeline.Extractor`:
|
|
||||||
The launched Extractor
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, arguments: Namespace, extractor: Optional[Extractor]) -> None:
|
def __init__(self, arguments: Namespace) -> None:
|
||||||
logger.debug("Initializing %s: (arguments: %s, extractor: %s)",
|
logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
|
||||||
self.__class__.__name__, arguments, extractor)
|
|
||||||
self._update_type = arguments.processing
|
self._update_type = arguments.processing
|
||||||
self._input_is_faces = arguments.input_type == "faces"
|
self._input_is_faces = arguments.input_type == "faces"
|
||||||
self._mask_type = arguments.masker
|
self._mask_type = arguments.masker
|
||||||
|
|
@ -159,7 +155,7 @@ class _Mask(): # pylint:disable=too-few-public-methods
|
||||||
self._faces_saver: Optional[ImagesSaver] = None
|
self._faces_saver: Optional[ImagesSaver] = None
|
||||||
|
|
||||||
self._alignments = self._get_alignments(arguments)
|
self._alignments = self._get_alignments(arguments)
|
||||||
self._extractor = extractor
|
self._extractor = self._get_extractor(arguments.exclude_gpus)
|
||||||
self._set_correct_mask_type()
|
self._set_correct_mask_type()
|
||||||
self._extractor_input_thread = self._feed_extractor()
|
self._extractor_input_thread = self._feed_extractor()
|
||||||
|
|
||||||
|
|
@ -246,6 +242,27 @@ class _Mask(): # pylint:disable=too-few-public-methods
|
||||||
|
|
||||||
return Alignments(folder, filename=filename)
|
return Alignments(folder, filename=filename)
|
||||||
|
|
||||||
|
def _get_extractor(self, exclude_gpus: List[int]) -> Optional[Extractor]:
|
||||||
|
""" Obtain a Mask extractor plugin and launch it
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
exclude_gpus: list or ``None``
|
||||||
|
A list of indices correlating to connected GPUs that Tensorflow should not use. Pass
|
||||||
|
``None`` to not exclude any GPUs.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
:class:`plugins.extract.pipeline.Extractor`:
|
||||||
|
The launched Extractor
|
||||||
|
"""
|
||||||
|
if self._update_type == "output":
|
||||||
|
logger.debug("Update type `output` selected. Not launching extractor")
|
||||||
|
return None
|
||||||
|
logger.debug("masker: %s", self._mask_type)
|
||||||
|
extractor = Extractor(None, None, self._mask_type, exclude_gpus=exclude_gpus)
|
||||||
|
extractor.launch()
|
||||||
|
logger.debug(extractor)
|
||||||
|
return extractor
|
||||||
|
|
||||||
def _set_correct_mask_type(self):
|
def _set_correct_mask_type(self):
|
||||||
""" Some masks have multiple variants that they can be saved as depending on config options
|
""" Some masks have multiple variants that they can be saved as depending on config options
|
||||||
so update the :attr:`_mask_type` accordingly
|
so update the :attr:`_mask_type` accordingly
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user