mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
bugfix: mask tool memory leak in batch-mode
This commit is contained in:
parent
d0a8d59812
commit
a076afa910
|
|
@ -4,6 +4,7 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from multiprocessing import Process
|
||||
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import cv2
|
||||
|
|
@ -79,28 +80,28 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Returning output: '%s' for input: '%s'", retval, input_location)
|
||||
return retval
|
||||
|
||||
def _get_extractor(self) -> Optional[Extractor]:
|
||||
""" Obtain a Mask extractor plugin and launch it
|
||||
@staticmethod
|
||||
def _run_mask_process(arguments: Namespace) -> None:
|
||||
""" The mask process to be run in a spawned process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`plugins.extract.pipeline.Extractor`:
|
||||
The launched Extractor
|
||||
In some instances, batch-mode memory leaks. Launching each job in a separate process
|
||||
prevents this leak.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arguments: :class:`argparse.Namespace`
|
||||
The :mod:`argparse` arguments to be used for the given job
|
||||
"""
|
||||
if self._args.processing == "output":
|
||||
logger.debug("Update type `output` selected. Not launching extractor")
|
||||
return None
|
||||
logger.debug("masker: %s", self._args.masker)
|
||||
extractor = Extractor(None, None, self._args.masker, exclude_gpus=self._args.exclude_gpus)
|
||||
logger.debug(extractor)
|
||||
return extractor
|
||||
logger.debug("Starting process: (arguments: %s)", arguments)
|
||||
mask = _Mask(arguments)
|
||||
mask.process()
|
||||
logger.debug("Finished process: (arguments: %s)", arguments)
|
||||
|
||||
def process(self) -> None:
|
||||
""" The entry point for triggering the Extraction Process.
|
||||
|
||||
Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
|
||||
"""
|
||||
extractor = self._get_extractor()
|
||||
for idx, location in enumerate(self._input_locations):
|
||||
if self._args.batch_mode:
|
||||
logger.info("Processing job %s of %s: %s",
|
||||
|
|
@ -115,14 +116,12 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
else:
|
||||
arguments = self._args
|
||||
|
||||
if extractor is not None:
|
||||
extractor.launch()
|
||||
|
||||
mask = _Mask(arguments, extractor)
|
||||
mask.process()
|
||||
|
||||
if extractor is not None:
|
||||
extractor.reset_phase_index()
|
||||
if len(self._input_locations) > 1:
|
||||
proc = Process(target=self._run_mask_process, args=(arguments, ))
|
||||
proc.start()
|
||||
proc.join()
|
||||
else:
|
||||
self._run_mask_process(arguments)
|
||||
|
||||
|
||||
class _Mask(): # pylint:disable=too-few-public-methods
|
||||
|
|
@ -136,12 +135,9 @@ class _Mask(): # pylint:disable=too-few-public-methods
|
|||
----------
|
||||
arguments: :class:`argparse.Namespace`
|
||||
The :mod:`argparse` arguments as passed in from :mod:`tools.py`
|
||||
extractor: :class:`plugins.extract.pipeline.Extractor`:
|
||||
The launched Extractor
|
||||
"""
|
||||
def __init__(self, arguments: Namespace, extractor: Optional[Extractor]) -> None:
|
||||
logger.debug("Initializing %s: (arguments: %s, extractor: %s)",
|
||||
self.__class__.__name__, arguments, extractor)
|
||||
def __init__(self, arguments: Namespace) -> None:
|
||||
logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
|
||||
self._update_type = arguments.processing
|
||||
self._input_is_faces = arguments.input_type == "faces"
|
||||
self._mask_type = arguments.masker
|
||||
|
|
@ -159,7 +155,7 @@ class _Mask(): # pylint:disable=too-few-public-methods
|
|||
self._faces_saver: Optional[ImagesSaver] = None
|
||||
|
||||
self._alignments = self._get_alignments(arguments)
|
||||
self._extractor = extractor
|
||||
self._extractor = self._get_extractor(arguments.exclude_gpus)
|
||||
self._set_correct_mask_type()
|
||||
self._extractor_input_thread = self._feed_extractor()
|
||||
|
||||
|
|
@ -246,6 +242,27 @@ class _Mask(): # pylint:disable=too-few-public-methods
|
|||
|
||||
return Alignments(folder, filename=filename)
|
||||
|
||||
def _get_extractor(self, exclude_gpus: List[int]) -> Optional[Extractor]:
|
||||
""" Obtain a Mask extractor plugin and launch it
|
||||
Parameters
|
||||
----------
|
||||
exclude_gpus: list or ``None``
|
||||
A list of indices correlating to connected GPUs that Tensorflow should not use. Pass
|
||||
``None`` to not exclude any GPUs.
|
||||
Returns
|
||||
-------
|
||||
:class:`plugins.extract.pipeline.Extractor`:
|
||||
The launched Extractor
|
||||
"""
|
||||
if self._update_type == "output":
|
||||
logger.debug("Update type `output` selected. Not launching extractor")
|
||||
return None
|
||||
logger.debug("masker: %s", self._mask_type)
|
||||
extractor = Extractor(None, None, self._mask_type, exclude_gpus=exclude_gpus)
|
||||
extractor.launch()
|
||||
logger.debug(extractor)
|
||||
return extractor
|
||||
|
||||
def _set_correct_mask_type(self):
|
||||
""" Some masks have multiple variants that they can be saved as depending on config options
|
||||
so update the :attr:`_mask_type` accordingly
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user