bugfix - FIx rare bug that fails to load configuration files on some windows installs

This commit is contained in:
torzdf 2021-02-17 18:36:30 +00:00
parent 69813de15a
commit 48ca4d1b0e
4 changed files with 48 additions and 83 deletions

View File

@ -10,6 +10,8 @@ from collections import OrderedDict
from configparser import ConfigParser
from importlib import import_module
from lib.utils import full_path_split
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -64,6 +66,49 @@ class FaceswapConfig():
"""
raise NotImplementedError
def _defaults_from_plugin(self, plugin_folder):
""" Scan the given plugins folder for config defaults.py files and update the
default configuration.
Parameters
----------
plugin_folder: str
The folder to scan for plugins
"""
for dirpath, _, filenames in os.walk(plugin_folder):
default_files = [fname for fname in filenames if fname.endswith("_defaults.py")]
if not default_files:
continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
# Can't use replace as there is a bug on some Windows installs that lowers some paths
import_path = ".".join(full_path_split(dirpath[len(base_path):])[1:])
plugin_type = import_path.split(".")[-1]
for filename in default_files:
self._load_defaults_from_module(filename, import_path, plugin_type)
def _load_defaults_from_module(self, filename, module_path, plugin_type):
""" Load the plugin's defaults module, extract defaults and add to default configuration.
Parameters
----------
filename: str
The filename to load the defaults from
module_path: str
The path to load the module from
plugin_type: str
The type of plugin that the defaults are being loaded for
"""
logger.debug("Adding defaults: (filename: %s, module_path: %s, plugin_type: %s",
filename, module_path, plugin_type)
module = os.path.splitext(filename)[0]
section = ".".join((plugin_type, module.replace("_defaults", "")))
logger.debug("Importing defaults module: %s.%s", module_path, module)
mod = import_module("{}.{}".format(module_path, module))
self.add_section(title=section, info=mod._HELPTEXT) # pylint:disable=protected-access
for key, val in mod._DEFAULTS.items(): # pylint:disable=protected-access
self.add_item(section=section, title=key, **val)
logger.debug("Added defaults: %s", section)
@property
def config_dict(self):
""" Collate global options and requested section into a dictionary with the correct

View File

@ -3,12 +3,8 @@
import logging
import os
import sys
from importlib import import_module
from lib.config import FaceswapConfig
from lib.utils import full_path_split
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -18,27 +14,4 @@ class Config(FaceswapConfig):
def set_defaults(self):
""" Set the default values for config """
logger.debug("Setting defaults")
current_dir = os.path.dirname(__file__)
for dirpath, _, filenames in os.walk(current_dir):
default_files = [fname for fname in filenames if fname.endswith("_defaults.py")]
if not default_files:
continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
import_path = ".".join(full_path_split(dirpath.replace(base_path, ""))[1:])
plugin_type = import_path.split(".")[-1]
for filename in default_files:
self.load_module(filename, import_path, plugin_type)
def load_module(self, filename, module_path, plugin_type):
""" Load the defaults module and add defaults """
logger.debug("Adding defaults: (filename: %s, module_path: %s, plugin_type: %s",
filename, module_path, plugin_type)
module = os.path.splitext(filename)[0]
section = ".".join((plugin_type, module.replace("_defaults", "")))
logger.debug("Importing defaults module: %s.%s", module_path, module)
mod = import_module("{}.{}".format(module_path, module))
self.add_section(title=section, info=mod._HELPTEXT) # pylint:disable=protected-access
for key, val in mod._DEFAULTS.items(): # pylint:disable=protected-access
self.add_item(section=section, title=key, **val)
logger.debug("Added defaults: %s", section)
self._defaults_from_plugin(os.path.dirname(__file__))

View File

@ -3,11 +3,8 @@
import logging
import os
import sys
from importlib import import_module
from lib.config import FaceswapConfig
from lib.utils import full_path_split
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -19,29 +16,7 @@ class Config(FaceswapConfig):
""" Set the default values for config """
logger.debug("Setting defaults")
self.set_globals()
current_dir = os.path.dirname(__file__)
for dirpath, _, filenames in os.walk(current_dir):
default_files = [fname for fname in filenames if fname.endswith("_defaults.py")]
if not default_files:
continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
import_path = ".".join(full_path_split(dirpath.replace(base_path, ""))[1:])
plugin_type = import_path.split(".")[-1]
for filename in default_files:
self.load_module(filename, import_path, plugin_type)
def load_module(self, filename, module_path, plugin_type):
""" Load the defaults module and add defaults """
logger.debug("Adding defaults: (filename: %s, module_path: %s, plugin_type: %s",
filename, module_path, plugin_type)
module = os.path.splitext(filename)[0]
section = ".".join((plugin_type, module.replace("_defaults", "")))
logger.debug("Importing defaults module: %s.%s", module_path, module)
mod = import_module("{}.{}".format(module_path, module))
self.add_section(title=section, info=mod._HELPTEXT) # pylint:disable=protected-access
for key, val in mod._DEFAULTS.items(): # pylint:disable=protected-access
self.add_item(section=section, title=key, **val)
logger.debug("Added defaults: %s", section)
self._defaults_from_plugin(os.path.dirname(__file__))
def set_globals(self):
"""

View File

@ -3,12 +3,8 @@
import logging
import os
import sys
from importlib import import_module
from lib.config import FaceswapConfig
from lib.utils import full_path_split
from plugins.plugin_loader import PluginLoader
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -25,16 +21,7 @@ class Config(FaceswapConfig):
logger.debug("Setting defaults")
self._set_globals()
self._set_loss()
current_dir = os.path.dirname(__file__)
for dirpath, _, filenames in os.walk(current_dir):
default_files = [fname for fname in filenames if fname.endswith("_defaults.py")]
if not default_files:
continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
import_path = ".".join(full_path_split(dirpath.replace(base_path, ""))[1:])
plugin_type = import_path.split(".")[-1]
for filename in default_files:
self.load_module(filename, import_path, plugin_type)
self._defaults_from_plugin(os.path.dirname(__file__))
def _set_globals(self):
""" Set the global options for training """
@ -386,18 +373,3 @@ class Config(FaceswapConfig):
info="Dedicate a portion of the model to learning how to duplicate the input "
"mask. Increases VRAM usage in exchange for learning a quick ability to try "
"to replicate more complex mask models.")
def load_module(self, filename, module_path, plugin_type):
""" Load the defaults module and add defaults """
logger.debug("Adding defaults: (filename: %s, module_path: %s, plugin_type: %s",
filename, module_path, plugin_type)
module = os.path.splitext(filename)[0]
section = ".".join((plugin_type, module.replace("_defaults", "")))
logger.debug("Importing defaults module: %s.%s", module_path, module)
mod = import_module("{}.{}".format(module_path, module))
helptext = mod._HELPTEXT # pylint:disable=protected-access
helptext += ADDITIONAL_INFO if module_path.endswith("model") else ""
self.add_section(title=section, info=helptext)
for key, val in mod._DEFAULTS.items(): # pylint:disable=protected-access
self.add_item(section=section, title=key, **val)
logger.debug("Added defaults: %s", section)