mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
unit tests added for weight utils
This commit is contained in:
parent
a80d3e8d27
commit
8f9c1e3b73
@ -19,6 +19,8 @@ else:
|
||||
|
||||
logger = Logger()
|
||||
|
||||
ALLOWED_COMPRESS_TYPES = ["zip", "bz2"]
|
||||
|
||||
|
||||
def download_weights_if_necessary(
|
||||
file_name: str, source_url: str, compress_type: Optional[str] = None
|
||||
@ -40,12 +42,15 @@ def download_weights_if_necessary(
|
||||
logger.debug(f"{file_name} is already available at {target_file}")
|
||||
return target_file
|
||||
|
||||
if compress_type is not None and compress_type not in ALLOWED_COMPRESS_TYPES:
|
||||
raise ValueError(f"unimplemented compress type - {compress_type}")
|
||||
|
||||
try:
|
||||
logger.info(f"🔗 {file_name} will be downloaded from {source_url} to {target_file}...")
|
||||
|
||||
if compress_type is None:
|
||||
gdown.download(source_url, target_file, quiet=False)
|
||||
elif compress_type is not None:
|
||||
elif compress_type is not None and compress_type in ALLOWED_COMPRESS_TYPES:
|
||||
gdown.download(source_url, f"{target_file}.{compress_type}", quiet=False)
|
||||
|
||||
except Exception as err:
|
||||
|
@ -1,15 +1,19 @@
|
||||
# built-in dependencies
|
||||
import os
|
||||
from unittest import mock
|
||||
import pytest
|
||||
|
||||
# project dependencies
|
||||
from deepface.commons import folder_utils, weight_utils, package_utils
|
||||
from deepface.commons.logger import Logger
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
logger = Logger()
|
||||
|
||||
tf_version = package_utils.get_tf_major_version()
|
||||
|
||||
# conditional imports
|
||||
if tf_version == 1:
|
||||
from keras.models import Sequential
|
||||
from keras.layers import (
|
||||
@ -41,9 +45,208 @@ def test_loading_broken_weights():
|
||||
|
||||
# vgg's weights cannot be loaded to this model
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="An exception occurred while loading the pre-trained weights from"
|
||||
ValueError, match="An exception occurred while loading the pre-trained weights from"
|
||||
):
|
||||
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
|
||||
|
||||
logger.info("✅ test loading broken weight file is done")
|
||||
|
||||
|
||||
@mock.patch("deepface.commons.folder_utils.get_deepface_home") # Update with your actual module
|
||||
@mock.patch("gdown.download") # Mocking gdown's download function
|
||||
@mock.patch("os.path.isfile") # Mocking os.path.isfile
|
||||
@mock.patch("os.makedirs") # Mocking os.makedirs to avoid FileNotFoundError
|
||||
@mock.patch("zipfile.ZipFile") # Mocking the ZipFile class
|
||||
@mock.patch("bz2.BZ2File") # Mocking the BZ2File class
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open()) # Mocking open
|
||||
class TestDownloadWeightFeature:
|
||||
def test_download_weights_for_available_file(
|
||||
self,
|
||||
mock_open,
|
||||
mock_zipfile,
|
||||
mock_bz2file,
|
||||
mock_makedirs,
|
||||
mock_isfile,
|
||||
mock_gdown,
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
mock_isfile.return_value = True
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.zip"
|
||||
|
||||
result = weight_utils.download_weights_if_necessary(file_name, source_url)
|
||||
|
||||
assert result == os.path.join("/mock/home", ".deepface/weights", file_name)
|
||||
|
||||
mock_gdown.assert_not_called()
|
||||
mock_zipfile.assert_not_called()
|
||||
mock_bz2file.assert_not_called()
|
||||
logger.info("✅ test download weights for available file is done")
|
||||
|
||||
def test_download_weights_if_necessary_gdown_failure(
|
||||
self,
|
||||
mock_open,
|
||||
mock_zipfile,
|
||||
mock_bz2file,
|
||||
mock_makedirs,
|
||||
mock_isfile,
|
||||
mock_gdown,
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.h5"
|
||||
|
||||
# Simulate gdown.download raising an exception
|
||||
mock_gdown.side_effect = Exception("Download failed!")
|
||||
|
||||
# Call the function and check for ValueError
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"⛓️💥 An exception occurred while downloading {file_name} from {source_url}.",
|
||||
):
|
||||
weight_utils.download_weights_if_necessary(file_name, source_url)
|
||||
|
||||
logger.info("✅ test for downloading weights while gdown fails done")
|
||||
|
||||
def test_download_weights_if_necessary_no_compression(
|
||||
self,
|
||||
mock_open,
|
||||
mock_zipfile,
|
||||
mock_bz2file,
|
||||
mock_makedirs,
|
||||
mock_isfile,
|
||||
mock_gdown,
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.h5"
|
||||
|
||||
# Call the function
|
||||
result = weight_utils.download_weights_if_necessary(file_name, source_url)
|
||||
|
||||
# Assert that gdown.download was called with the correct parameters
|
||||
mock_gdown.assert_called_once_with(
|
||||
source_url, "/mock/home/.deepface/weights/model_weights.h5", quiet=False
|
||||
)
|
||||
|
||||
# Assert that the return value is correct
|
||||
assert result == "/mock/home/.deepface/weights/model_weights.h5"
|
||||
|
||||
# Assert that zipfile.ZipFile and bz2.BZ2File were not called
|
||||
mock_zipfile.assert_not_called()
|
||||
mock_bz2file.assert_not_called()
|
||||
|
||||
logger.info("✅ test download weights with no compression is done")
|
||||
|
||||
def test_download_weights_if_necessary_zip(
|
||||
self,
|
||||
mock_open,
|
||||
mock_zipfile,
|
||||
mock_bz2file,
|
||||
mock_makedirs,
|
||||
mock_isfile,
|
||||
mock_gdown,
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.zip"
|
||||
compress_type = "zip"
|
||||
|
||||
# Call the function
|
||||
result = weight_utils.download_weights_if_necessary(file_name, source_url, compress_type)
|
||||
|
||||
# Assert that gdown.download was called with the correct parameters
|
||||
mock_gdown.assert_called_once_with(
|
||||
source_url, "/mock/home/.deepface/weights/model_weights.h5.zip", quiet=False
|
||||
)
|
||||
|
||||
# Simulate the unzipping behavior
|
||||
mock_zipfile.return_value.__enter__.return_value.extractall = mock.Mock()
|
||||
|
||||
# Call the function again to simulate unzipping
|
||||
with mock_zipfile.return_value as zip_ref:
|
||||
zip_ref.extractall("/mock/home/.deepface/weights")
|
||||
|
||||
# Assert that the zip file was unzipped correctly
|
||||
zip_ref.extractall.assert_called_once_with("/mock/home/.deepface/weights")
|
||||
|
||||
# Assert that the return value is correct
|
||||
assert result == "/mock/home/.deepface/weights/model_weights.h5"
|
||||
|
||||
logger.info("✅ test download weights for zip is done")
|
||||
|
||||
def test_download_weights_if_necessary_bz2(
|
||||
self,
|
||||
mock_open,
|
||||
mock_zipfile,
|
||||
mock_bz2file,
|
||||
mock_makedirs,
|
||||
mock_isfile,
|
||||
mock_gdown,
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.bz2"
|
||||
compress_type = "bz2"
|
||||
|
||||
# Simulate the download success
|
||||
mock_gdown.return_value = None
|
||||
|
||||
# Simulate the BZ2 file reading behavior
|
||||
mock_bz2file.return_value.__enter__.return_value.read.return_value = b"fake data"
|
||||
|
||||
# Call the function under test
|
||||
result = weight_utils.download_weights_if_necessary(file_name, source_url, compress_type)
|
||||
|
||||
# Assert that gdown.download was called with the correct parameters
|
||||
mock_gdown.assert_called_once_with(
|
||||
source_url, "/mock/home/.deepface/weights/model_weights.h5.bz2", quiet=False
|
||||
)
|
||||
|
||||
# Ensure open() is called once for writing the decompressed data
|
||||
mock_open.assert_called_once_with("/mock/home/.deepface/weights/model_weights.h5", "wb")
|
||||
|
||||
# TODO: find a way to check write is called
|
||||
|
||||
# Assert that the return value is correct
|
||||
assert result == "/mock/home/.deepface/weights/model_weights.h5"
|
||||
|
||||
logger.info("✅ test download weights for bz2 is done")
|
||||
|
||||
def test_download_weights_for_non_supported_compress_type(
|
||||
self,
|
||||
mock_open,
|
||||
mock_zipfile,
|
||||
mock_bz2file,
|
||||
mock_makedirs,
|
||||
mock_isfile,
|
||||
mock_gdown,
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
mock_isfile.return_value = False
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.bz2"
|
||||
compress_type = "7z"
|
||||
with pytest.raises(ValueError, match="unimplemented compress type - 7z"):
|
||||
_ = weight_utils.download_weights_if_necessary(file_name, source_url, compress_type)
|
||||
logger.info("✅ test download weights for unsupported compress type is done")
|
||||
|
Loading…
x
Reference in New Issue
Block a user