mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 12:20:27 +01:00
620 lines
23 KiB
Python
620 lines
23 KiB
Python
#!/usr/bin python3
|
|
""" Pytest unit tests for :mod:`lib.utils` """
|
|
import os
|
|
import time
|
|
import warnings
|
|
import zipfile
|
|
|
|
from io import StringIO
|
|
from socket import timeout as socket_timeout, error as socket_error
|
|
from shutil import rmtree
|
|
from typing import Any, cast, List, Tuple, Union
|
|
from unittest.mock import MagicMock
|
|
from urllib import error as urlliberror
|
|
|
|
import pytest
|
|
import pytest_mock
|
|
|
|
from lib import utils
|
|
from lib.utils import (
|
|
_Backend, camel_case_split, convert_to_secs, DebugTimes, deprecation_warning, FaceswapError,
|
|
full_path_split, get_backend, get_dpi, get_folder, get_image_paths, get_tf_version, GetModel,
|
|
safe_shutdown, set_backend, set_system_verbosity)
|
|
|
|
from lib.logger import log_setup
|
|
# Need to setup logging to avoid trace/verbose errors
|
|
log_setup("DEBUG", "pytest_utils.log", "PyTest, False")
|
|
|
|
|
|
# pylint:disable=protected-access
|
|
|
|
|
|
# Backend tests
|
|
def test_set_backend(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
""" Test the :func:`~lib.utils.set_backend` function
|
|
|
|
Parameters
|
|
----------
|
|
monkeypatch: :class:`pytest.MonkeyPatch`
|
|
Monkey patching _FS_BACKEND
|
|
"""
|
|
monkeypatch.setattr(utils, "_FS_BACKEND", "cpu") # _FS_BACKEND already defined
|
|
set_backend("directml")
|
|
assert utils._FS_BACKEND == "directml"
|
|
monkeypatch.delattr(utils, "_FS_BACKEND") # _FS_BACKEND is not already defined
|
|
set_backend("amd")
|
|
assert utils._FS_BACKEND == "amd"
|
|
|
|
|
|
def test_get_backend(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
""" Test the :func:`~lib.utils.get_backend` function
|
|
|
|
Parameters
|
|
----------
|
|
monkeypatch: :class:`pytest.MonkeyPatch`
|
|
Monkey patching _FS_BACKEND
|
|
"""
|
|
monkeypatch.setattr(utils, "_FS_BACKEND", "apple-silicon")
|
|
assert get_backend() == "apple-silicon"
|
|
|
|
|
|
def test__backend(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
""" Test the :class:`~lib.utils._Backend` class
|
|
|
|
Parameters
|
|
----------
|
|
monkeypatch: :class:`pytest.MonkeyPatch`
|
|
Monkey patching :func:`os.environ`, :func:`os.path.isfile`, :func:`builtins.open` and
|
|
:func:`builtins.input`
|
|
"""
|
|
monkeypatch.setattr("os.environ", {"FACESWAP_BACKEND": "nvidia"}) # Environment variable set
|
|
backend = _Backend()
|
|
assert backend.backend == "nvidia"
|
|
|
|
monkeypatch.setattr("os.environ", {}) # Environment variable not set, dummy in config file
|
|
monkeypatch.setattr("os.path.isfile", lambda x: True)
|
|
monkeypatch.setattr("builtins.open", lambda *args, **kwargs: StringIO('{"backend": "amd"}'))
|
|
backend = _Backend()
|
|
assert backend.backend == "amd"
|
|
|
|
monkeypatch.setattr("os.path.isfile", lambda x: False) # no config file, dummy in user input
|
|
monkeypatch.setattr("builtins.input", lambda x: "3")
|
|
backend = _Backend()
|
|
assert backend._configure_backend() == "nvidia"
|
|
|
|
|
|
# Folder and path utils
|
|
def test_get_folder(tmp_path: str) -> None:
|
|
""" Unit test for :func:`~lib.utils.get_folder`
|
|
|
|
Parameters
|
|
----------
|
|
tmp_path: str
|
|
pytest temporary path to generate folders
|
|
"""
|
|
# New folder
|
|
path = os.path.join(tmp_path, "test_new_folder")
|
|
expected_output = path
|
|
assert not os.path.isdir(path)
|
|
assert get_folder(path) == expected_output
|
|
assert os.path.isdir(path)
|
|
|
|
# Test not creating a new folder when it already exists
|
|
path = os.path.join(tmp_path, "test_new_folder")
|
|
expected_output = path
|
|
assert os.path.isdir(path)
|
|
stats = os.stat(path)
|
|
assert get_folder(path) == expected_output
|
|
assert os.path.isdir(path)
|
|
assert stats == os.stat(path)
|
|
|
|
# Test not creating a new folder when make_folder is False
|
|
path = os.path.join(tmp_path, "test_no_folder")
|
|
expected_output = ""
|
|
assert get_folder(path, make_folder=False) == expected_output
|
|
assert not os.path.isdir(path)
|
|
|
|
|
|
def test_get_image_paths(tmp_path: str) -> None:
|
|
""" Unit test for :func:`~lib.utils.test_get_image_paths`
|
|
|
|
Parameters
|
|
----------
|
|
tmp_path: str
|
|
pytest temporary path to generate folders
|
|
"""
|
|
# Test getting image paths from a folder with no images
|
|
test_folder = os.path.join(tmp_path, "test_image_folder")
|
|
os.makedirs(test_folder)
|
|
assert not get_image_paths(test_folder)
|
|
|
|
# Populate 2 different image files and 1 text file
|
|
test_jpg_path = os.path.join(test_folder, "test_image.jpg")
|
|
test_png_path = os.path.join(test_folder, "test_image.png")
|
|
test_txt_path = os.path.join(test_folder, "test_file.txt")
|
|
for fname in (test_jpg_path, test_png_path, test_txt_path):
|
|
with open(fname, "a", encoding="utf-8"):
|
|
pass
|
|
|
|
# Test getting any image paths from a folder with images and random files
|
|
exists = [os.path.join(test_folder, img)
|
|
for img in os.listdir(test_folder) if os.path.splitext(img)[-1] != ".txt"]
|
|
assert sorted(get_image_paths(test_folder)) == sorted(exists)
|
|
|
|
# Test getting image paths from a folder with images with a specific extension
|
|
exists = [os.path.join(test_folder, img)
|
|
for img in os.listdir(test_folder) if os.path.splitext(img)[-1] == ".png"]
|
|
assert sorted(get_image_paths(test_folder, extension=".png")) == sorted(exists)
|
|
|
|
|
|
_PARAMS = [("/path/to/file.txt", ["/", "path", "to", "file.txt"]), # Absolute
|
|
("/path/to/directory/", ["/", "path", "to", "directory"]),
|
|
("/path/to/directory", ["/", "path", "to", "directory"]),
|
|
("path/to/file.txt", ["path", "to", "file.txt"]), # Relative
|
|
("path/to/directory/", ["path", "to", "directory"]),
|
|
("path/to/directory", ["path", "to", "directory"]),
|
|
("", []), # Edge cases
|
|
("/", ["/"]),
|
|
(".", ["."]),
|
|
("..", [".."])]
|
|
|
|
|
|
@pytest.mark.parametrize("path,result", _PARAMS, ids=[f'"{p[0]}"' for p in _PARAMS])
|
|
def test_full_path_split(path: str, result: List[str]) -> None:
|
|
""" Test the :func:`~lib.utils.full_path_split` function works correctly
|
|
|
|
Parameters
|
|
----------
|
|
path: str
|
|
The path to test
|
|
result: list
|
|
The expected result from the path
|
|
"""
|
|
split = full_path_split(path)
|
|
assert isinstance(split, list)
|
|
assert split == result
|
|
|
|
|
|
_PARAMS = [("camelCase", ["camel", "Case"]),
|
|
("camelCaseTest", ["camel", "Case", "Test"]),
|
|
("camelCaseTestCase", ["camel", "Case", "Test", "Case"]),
|
|
("CamelCase", ["Camel", "Case"]),
|
|
("CamelCaseTest", ["Camel", "Case", "Test"]),
|
|
("CamelCaseTestCase", ["Camel", "Case", "Test", "Case"]),
|
|
("CAmelCASETestCase", ["C", "Amel", "CASE", "Test", "Case"]),
|
|
("camelcasetestcase", ["camelcasetestcase"]),
|
|
("CAMELCASETESTCASE", ["CAMELCASETESTCASE"]),
|
|
("", [])]
|
|
|
|
|
|
@pytest.mark.parametrize("text, result", _PARAMS, ids=[f'"{p[0]}"' for p in _PARAMS])
|
|
def test_camel_case_split(text: str, result: List[str]) -> None:
|
|
""" Test the :func:`~lib.utils.camel_case_spli` function works correctly
|
|
|
|
Parameters
|
|
----------
|
|
text: str
|
|
The camel case text to test
|
|
result: list
|
|
The expected result from the path
|
|
"""
|
|
split = camel_case_split(text)
|
|
assert isinstance(split, list)
|
|
assert split == result
|
|
|
|
|
|
# General utils
|
|
def test_get_tf_version() -> None:
|
|
""" Test the :func:`~lib.utils.get_tf_version` function version returns correctly in range """
|
|
tf_version = get_tf_version()
|
|
assert (2, 2) <= tf_version < (2, 11)
|
|
|
|
|
|
def test_get_dpi() -> None:
|
|
""" Test the :func:`~lib.utils.get_dpi` function version returns correctly in a sane
|
|
range """
|
|
dpi = get_dpi()
|
|
assert isinstance(dpi, float) or dpi is None
|
|
if dpi is None: # No display detected
|
|
return
|
|
assert dpi > 0
|
|
assert dpi < 600.0
|
|
|
|
|
|
_SECPARAMS = [((1, ), 1), # 1 argument
|
|
((10, ), 10),
|
|
((0, 1), 1),
|
|
((0, 60), 60), # 2 arguments
|
|
((1, 0), 60),
|
|
((1, 1), 61),
|
|
((0, 0, 1), 1),
|
|
((0, 0, 60), 60), # 3 arguments
|
|
((0, 1, 0), 60),
|
|
((1, 0, 0), 3600),
|
|
((1, 1, 1), 3661)]
|
|
|
|
|
|
@pytest.mark.parametrize("args,result", _SECPARAMS, ids=[str(p[0]) for p in _SECPARAMS])
|
|
def test_convert_to_secs(args: Tuple[int, ...], result: int) -> None:
|
|
""" Test the :func:`~lib.utils.convert_to_secs` function works correctly
|
|
|
|
Parameters
|
|
----------
|
|
args: tuple
|
|
Tuple of 1, 2 or 3 integers to pass to the function
|
|
result: int
|
|
The expected results for the args tuple
|
|
"""
|
|
secs = convert_to_secs(*args)
|
|
assert isinstance(secs, int)
|
|
assert secs == result
|
|
|
|
|
|
@pytest.mark.parametrize("log_level", ["DEBUG", "INFO", "WARNING", "ERROR"])
|
|
def test_set_system_verbosity(log_level: str) -> None:
|
|
""" Test the :func:`~lib.utils.set_system_verbosity` function works correctly
|
|
|
|
Parameters
|
|
----------
|
|
log_level: str
|
|
The logging loglevel in upper text format
|
|
"""
|
|
# Set TF Env Variable
|
|
tf_set_level = "0" if log_level == "DEBUG" else "3"
|
|
set_system_verbosity(log_level)
|
|
tf_get_level = os.environ["TF_CPP_MIN_LOG_LEVEL"]
|
|
assert tf_get_level == tf_set_level
|
|
warn_filters = [filt for filt in warnings.filters
|
|
if filt[0] == "ignore"
|
|
and filt[2] in (FutureWarning, DeprecationWarning, UserWarning)]
|
|
# Python Warnings
|
|
# DeprecationWarning is already ignored by default, so there should be 1 warning for debug
|
|
# warning. 3 for the rest
|
|
num_warnings = 1 if log_level == "DEBUG" else 3
|
|
warn_count = len(warn_filters)
|
|
assert warn_count == num_warnings
|
|
|
|
|
|
@pytest.mark.parametrize("additional_info", [None, "additional information"])
|
|
def test_deprecation_warning(caplog: pytest.LogCaptureFixture, additional_info: str) -> None:
|
|
""" Test the :func:`~lib.utils.deprecation_warning` function works correctly
|
|
|
|
Parameters
|
|
----------
|
|
caplog: :class:`pytest.LogCaptureFixture`
|
|
Pytest's log capturing fixture
|
|
additional_info: str
|
|
Additional information to pass to the warning function
|
|
"""
|
|
func_name = "function_name"
|
|
test = f"{func_name} has been deprecated and will be removed from a future update."
|
|
if additional_info:
|
|
test = f"{test} {additional_info}"
|
|
deprecation_warning(func_name, additional_info=additional_info)
|
|
assert test in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize("got_error", [True, False])
|
|
def test_safe_shutdown(caplog: pytest.LogCaptureFixture, got_error: bool) -> None:
|
|
""" Test the :func:`~lib.utils.safe_shutdown` function works correctly
|
|
|
|
Parameters
|
|
----------
|
|
caplog: :class:`pytest.LogCaptureFixture`
|
|
Pytest's log capturing fixture
|
|
got_error: bool
|
|
The got_error parameter to pass to safe_shutdown
|
|
"""
|
|
caplog.set_level("DEBUG")
|
|
with pytest.raises(SystemExit) as wrapped_exit:
|
|
safe_shutdown(got_error=got_error)
|
|
|
|
exit_value = 1 if got_error else 0
|
|
assert wrapped_exit.typename == "SystemExit"
|
|
assert wrapped_exit.value.code == exit_value
|
|
assert "Safely shutting down" in caplog.messages
|
|
assert "Cleanup complete. Shutting down queue manager and exiting" in caplog.messages
|
|
|
|
|
|
def test_faceswap_error():
|
|
""" Test the :class:`~lib.utils.FaceswapError` raises correctly """
|
|
with pytest.raises(Exception):
|
|
raise FaceswapError
|
|
|
|
|
|
# GetModel class
|
|
@pytest.fixture(name="get_model_instance")
|
|
def fixture_get_model_instance(monkeypatch: pytest.MonkeyPatch,
|
|
tmp_path: pytest.TempdirFactory,
|
|
request: pytest.FixtureRequest) -> GetModel:
|
|
""" Create a fixture of the :class:`~lib.utils.GetModel` object, prevent _get() from running at
|
|
__init__ and point the cache_dir at our local test folder """
|
|
cache_dir = os.path.join(str(tmp_path), "get_model")
|
|
os.mkdir(cache_dir)
|
|
|
|
model_filename = "test_model_file_v1.h5"
|
|
git_model_id = 123
|
|
|
|
original_get = GetModel._get
|
|
# Patch out _get() so it is not called from __init__()
|
|
monkeypatch.setattr(utils.GetModel, "_get", lambda x: None)
|
|
model_instance = GetModel(model_filename, git_model_id)
|
|
# Reinsert _get() so we can test it
|
|
monkeypatch.setattr(model_instance, "_get", original_get)
|
|
model_instance._cache_dir = cache_dir
|
|
|
|
def teardown():
|
|
rmtree(cache_dir)
|
|
|
|
request.addfinalizer(teardown)
|
|
return model_instance
|
|
|
|
|
|
_INPUT = ("test_model_file_v3.h5",
|
|
["test_multi_model_file_v1.1.npy", "test_multi_model_file_v1.2.npy"])
|
|
_EXPECTED = ((["test_model_file_v3.h5"], "test_model_file_v3", "test_model_file", 3),
|
|
(["test_multi_model_file_v1.1.npy", "test_multi_model_file_v1.2.npy"],
|
|
"test_multi_model_file_v1", "test_multi_model_file", 1))
|
|
|
|
|
|
@pytest.mark.parametrize("filename,results", zip(_INPUT, _EXPECTED), ids=[str(i) for i in _INPUT])
|
|
def test_get_model_model_filename_input(
|
|
get_model_instance: GetModel, # pylint:disable=unused-argument
|
|
filename: Union[str, List[str]],
|
|
results: Union[str, List[str]]) -> None:
|
|
""" Test :class:`~lib.utils.GetModel` filename parsing works
|
|
|
|
Parameters
|
|
---------
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
filename: list or str
|
|
The test filenames
|
|
results: tuple
|
|
The expected results for :attr:`_model_filename`, :attr:`_model_full_name`,
|
|
:attr:`_model_name`, :attr:`_model_version` respectively
|
|
"""
|
|
model = GetModel(filename, 123)
|
|
assert model._model_filename == results[0]
|
|
assert model._model_full_name == results[1]
|
|
assert model._model_name == results[2]
|
|
assert model._model_version == results[3]
|
|
|
|
|
|
def test_get_model_attributes(get_model_instance: GetModel) -> None:
|
|
""" Test :class:`~lib.utils.GetModel` private attributes set correctly
|
|
|
|
Parameters
|
|
---------
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
"""
|
|
model = get_model_instance
|
|
assert model._git_model_id == 123
|
|
assert model._url_base == ("https://github.com/deepfakes-models/faceswap-models"
|
|
"/releases/download")
|
|
assert model._chunk_size == 1024
|
|
assert model._retries == 6
|
|
|
|
|
|
def test_get_model_properties(get_model_instance: GetModel) -> None:
|
|
""" Test :class:`~lib.utils.GetModel` calculated attributes return correctly
|
|
|
|
Parameters
|
|
---------
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
"""
|
|
model = get_model_instance
|
|
assert model.model_path == os.path.join(model._cache_dir, "test_model_file_v1.h5")
|
|
assert model._model_zip_path == os.path.join(model._cache_dir, "test_model_file_v1.zip")
|
|
assert not model._model_exists
|
|
assert model._url_download == ("https://github.com/deepfakes-models/faceswap-models/releases/"
|
|
"download/v123.1/test_model_file_v1.zip")
|
|
assert model._url_partial_size == 0
|
|
|
|
|
|
@pytest.mark.parametrize("model_exists", (True, False))
|
|
def test_get_model__get(mocker: pytest_mock.MockerFixture,
|
|
get_model_instance: GetModel,
|
|
model_exists: bool) -> None:
|
|
""" Test :func:`~lib.utils.GetModel._get` executes logic correctly
|
|
|
|
Parameters
|
|
---------
|
|
mocker: :class:`pytest_mock.MockerFixture`
|
|
Mocker for dummying in function calls
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
model_exists: bool
|
|
For testing the function when a model exists and when it does not
|
|
"""
|
|
model = get_model_instance
|
|
model._download_model = cast(MagicMock, mocker.MagicMock()) # type:ignore
|
|
model._unzip_model = cast(MagicMock, mocker.MagicMock()) # type:ignore
|
|
os_remove = mocker.patch("os.remove")
|
|
|
|
if model_exists: # Dummy in a model file
|
|
assert isinstance(model.model_path, str)
|
|
with open(model.model_path, "a", encoding="utf-8"):
|
|
pass
|
|
|
|
model._get(model) # type:ignore
|
|
|
|
assert (model_exists and not model._download_model.called) or (
|
|
not model_exists and model._download_model.called)
|
|
assert (model_exists and not model._unzip_model.called) or (
|
|
not model_exists and model._unzip_model.called)
|
|
assert model_exists or not (model_exists and os_remove.called)
|
|
os_remove.reset_mock()
|
|
|
|
|
|
_DLPARAMS = [(None, None),
|
|
(socket_error, ()),
|
|
(socket_timeout, ()),
|
|
(urlliberror.URLError, ("test_reason", )),
|
|
(urlliberror.HTTPError, ("test_uri", 400, "", "", 0))]
|
|
|
|
|
|
@pytest.mark.parametrize("error_type,error_args", _DLPARAMS, ids=[str(p[0]) for p in _DLPARAMS])
|
|
def test_get_model__download_model(mocker: pytest_mock.MockerFixture,
|
|
get_model_instance: GetModel,
|
|
error_type: Any,
|
|
error_args: Tuple[Union[str, int], ...]) -> None:
|
|
""" Test :func:`~lib.utils.GetModel._download_model` executes its logic correctly
|
|
|
|
Parameters
|
|
---------
|
|
mocker: :class:`pytest_mock.MockerFixture`
|
|
Mocker for dummying in function calls
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
error_type: connection error type or ``None``
|
|
Connection error type to mock, or ``None`` for succesful download
|
|
error_args: tuple
|
|
The arguments to be passed to the exception to be raised
|
|
"""
|
|
mock_urlopen = mocker.patch("urllib.request.urlopen")
|
|
if not error_type: # Model download is successful
|
|
get_model_instance._write_zipfile = cast(MagicMock, mocker.MagicMock()) # type:ignore
|
|
get_model_instance._download_model()
|
|
assert mock_urlopen.called
|
|
assert get_model_instance._write_zipfile.called
|
|
else: # Test that the process exits on download errors
|
|
mock_urlopen.side_effect = error_type(*error_args)
|
|
with pytest.raises(SystemExit):
|
|
get_model_instance._download_model()
|
|
mock_urlopen.reset_mock()
|
|
|
|
|
|
@pytest.mark.parametrize("dl_type", ["complete", "new", "continue"])
|
|
def test_get_model__write_zipfile(mocker: pytest_mock.MockerFixture,
|
|
get_model_instance: GetModel,
|
|
dl_type: str) -> None:
|
|
""" Test :func:`~lib.utils.GetModel._write_zipfile` executes its logic correctly
|
|
|
|
Parameters
|
|
---------
|
|
mocker: :class:`pytest_mock.MockerFixture`
|
|
Mocker for dummying in function calls
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
dl_type: str
|
|
The type of read to attemp
|
|
"""
|
|
response = mocker.MagicMock()
|
|
assert not os.path.isfile(get_model_instance._model_zip_path)
|
|
|
|
downloaded = 10 if dl_type == "complete" else 0
|
|
response.getheader.return_value = 0
|
|
|
|
if dl_type in ("new", "continue"):
|
|
chunks = [32, 64, 128, 256, 512, 1024]
|
|
data = [b"\x00" * size for size in chunks] + [b""]
|
|
response.getheader.return_value = sum(chunks)
|
|
response.read.side_effect = data
|
|
|
|
if dl_type == "continue": # Write a partial download of the correct size
|
|
with open(get_model_instance._model_zip_path, "wb") as partial:
|
|
partial.write(b"\x00" * sum(chunks))
|
|
downloaded = os.path.getsize(get_model_instance._model_zip_path)
|
|
|
|
get_model_instance._write_zipfile(response, downloaded)
|
|
|
|
if dl_type == "complete": # Already downloaded. No more tests
|
|
assert not response.read.called
|
|
return
|
|
|
|
assert response.read.call_count == len(data) # all data read
|
|
assert os.path.isfile(get_model_instance._model_zip_path)
|
|
downloaded_size = os.path.getsize(get_model_instance._model_zip_path)
|
|
downloaded_size = downloaded_size if dl_type == "new" else downloaded_size // 2
|
|
assert downloaded_size == sum(chunks)
|
|
|
|
|
|
def test_get_model__unzip_model(mocker: pytest_mock.MockerFixture,
|
|
get_model_instance: GetModel) -> None:
|
|
""" Test :func:`~lib.utils.GetModel._unzip_model` executes its logic correctly
|
|
|
|
Parameters
|
|
---------
|
|
mocker: :class:`pytest_mock.MockerFixture`
|
|
Mocker for dummying in function calls
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
"""
|
|
mock_zipfile = mocker.patch("zipfile.ZipFile")
|
|
# Successful
|
|
get_model_instance._unzip_model()
|
|
assert mock_zipfile.called
|
|
mock_zipfile.reset_mock()
|
|
# Error
|
|
mock_zipfile.side_effect = zipfile.BadZipFile()
|
|
with pytest.raises(SystemExit):
|
|
get_model_instance._unzip_model()
|
|
mock_zipfile.reset_mock()
|
|
|
|
|
|
def test_get_model__write_model(mocker: pytest_mock.MockerFixture,
|
|
get_model_instance: GetModel) -> None:
|
|
""" Test :func:`~lib.utils.GetModel._write_model` executes its logic correctly
|
|
|
|
Parameters
|
|
---------
|
|
mocker: :class:`pytest_mock.MockerFixture`
|
|
Mocker for dummying in function calls
|
|
get_model_instance: `~lib.utils.GetModel`
|
|
The patched instance of the class
|
|
"""
|
|
out_file = os.path.join(get_model_instance._cache_dir, get_model_instance._model_filename[0])
|
|
chunks = [8, 16, 32, 64, 128, 256, 512, 1024]
|
|
data = [b"\x00" * size for size in chunks] + [b""]
|
|
assert not os.path.isfile(out_file)
|
|
mock_zipfile = mocker.patch("zipfile.ZipFile")
|
|
mock_zipfile.namelist.return_value = get_model_instance._model_filename
|
|
mock_zipfile.open.return_value = mock_zipfile
|
|
mock_zipfile.read.side_effect = data
|
|
get_model_instance._write_model(mock_zipfile)
|
|
assert mock_zipfile.read.call_count == len(data)
|
|
assert os.path.isfile(out_file)
|
|
assert os.path.getsize(out_file) == sum(chunks)
|
|
|
|
|
|
# DebugTimes class
|
|
def test_debug_times():
|
|
""" Test :class:`~lib.utils.DebugTimes` executes its logic correctly """
|
|
debug_times = DebugTimes()
|
|
|
|
debug_times.step_start("Test1")
|
|
time.sleep(0.1)
|
|
debug_times.step_end("Test1")
|
|
|
|
debug_times.step_start("Test2")
|
|
time.sleep(0.2)
|
|
debug_times.step_end("Test2")
|
|
|
|
debug_times.step_start("Test1")
|
|
time.sleep(0.1)
|
|
debug_times.step_end("Test1")
|
|
|
|
debug_times.summary()
|
|
|
|
# Ensure that the summary method prints the min, mean, and max times for each step
|
|
assert debug_times._display["min"] is True
|
|
assert debug_times._display["mean"] is True
|
|
assert debug_times._display["max"] is True
|
|
|
|
# Ensure that the summary method includes the correct number of items for each step
|
|
assert len(debug_times._times["Test1"]) == 2
|
|
assert len(debug_times._times["Test2"]) == 1
|
|
|
|
# Ensure that the summary method includes the correct min, mean, and max times for each step
|
|
assert min(debug_times._times["Test1"]) == pytest.approx(0.1, abs=1e-1)
|
|
assert min(debug_times._times["Test2"]) == pytest.approx(0.2, abs=1e-1)
|
|
assert max(debug_times._times["Test1"]) == pytest.approx(0.1, abs=1e-1)
|
|
assert max(debug_times._times["Test2"]) == pytest.approx(0.2, abs=1e-1)
|
|
assert (sum(debug_times._times["Test1"]) /
|
|
len(debug_times._times["Test1"])) == pytest.approx(0.1, abs=1e-1)
|
|
assert (sum(debug_times._times["Test2"]) /
|
|
len(debug_times._times["Test2"]) == pytest.approx(0.2, abs=1e-1))
|