mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
customize broken weight file test
This commit is contained in:
parent
e8346601e2
commit
005280f684
46
tests/test_commons.py
Normal file
46
tests/test_commons.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# built-in dependencies
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# project dependencies
|
||||||
|
from deepface.commons import folder_utils, weight_utils, package_utils
|
||||||
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
|
logger = Logger()
|
||||||
|
|
||||||
|
tf_version = package_utils.get_tf_major_version()
|
||||||
|
|
||||||
|
if tf_version == 1:
|
||||||
|
from keras.models import Sequential
|
||||||
|
from keras.layers import (
|
||||||
|
Dropout,
|
||||||
|
Dense,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from tensorflow.keras.models import Sequential
|
||||||
|
from tensorflow.keras.layers import (
|
||||||
|
Dropout,
|
||||||
|
Dense,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loading_broken_weights():
|
||||||
|
home = folder_utils.get_deepface_home()
|
||||||
|
weight_file = os.path.join(home, ".deepface/weights/vgg_face_weights.h5")
|
||||||
|
|
||||||
|
# construct a dummy model
|
||||||
|
model = Sequential()
|
||||||
|
|
||||||
|
# Add layers to the model
|
||||||
|
model.add(
|
||||||
|
Dense(units=64, activation="relu", input_shape=(100,))
|
||||||
|
) # Input layer with 100 features
|
||||||
|
model.add(Dropout(0.5)) # Dropout layer to prevent overfitting
|
||||||
|
model.add(Dense(units=32, activation="relu")) # Hidden layer
|
||||||
|
model.add(Dense(units=10, activation="softmax")) # Output layer with 10 classes
|
||||||
|
|
||||||
|
# vgg's weights cannot be loaded to this model
|
||||||
|
with pytest.raises(ValueError, match="Exception while loading pre-trained weights from"):
|
||||||
|
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
|
||||||
|
|
||||||
|
logger.info("✅ test loading broken weight file is done")
|
@ -1,13 +1,9 @@
|
|||||||
# built-in dependencies
|
|
||||||
import os
|
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import pytest
|
import pytest
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
# project dependencies
|
# project dependencies
|
||||||
from deepface import DeepFace
|
from deepface import DeepFace
|
||||||
from deepface.commons import folder_utils, package_utils
|
|
||||||
from deepface.commons.logger import Logger
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
@ -192,35 +188,3 @@ def test_verify_for_nested_embeddings():
|
|||||||
_ = DeepFace.verify(img1_path=img1_embeddings, img2_path=img2_path)
|
_ = DeepFace.verify(img1_path=img1_embeddings, img2_path=img2_path)
|
||||||
|
|
||||||
logger.info("✅ test verify for nested embeddings is done")
|
logger.info("✅ test verify for nested embeddings is done")
|
||||||
|
|
||||||
|
|
||||||
def test_verify_for_broken_weights():
|
|
||||||
home = folder_utils.get_deepface_home()
|
|
||||||
|
|
||||||
# we are not performing anything with model deepid
|
|
||||||
|
|
||||||
weights_file = os.path.join(home, ".deepface/weights/deepid_keras_weights.h5")
|
|
||||||
backup_file = os.path.join(home, ".deepface/weights/deepid_keras_weights_backup.h5")
|
|
||||||
|
|
||||||
restore = False
|
|
||||||
# backup original weight file
|
|
||||||
if os.path.exists(weights_file) is True:
|
|
||||||
os.rename(weights_file, backup_file)
|
|
||||||
restore = True
|
|
||||||
|
|
||||||
# Create a dummy vgg_face_weights.h5 file
|
|
||||||
with open(weights_file, "w", encoding="UTF-8") as f:
|
|
||||||
f.write("dummy content")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Exception while loading pre-trained weights from"):
|
|
||||||
_ = DeepFace.verify(
|
|
||||||
img1_path="dataset/img1.jpg",
|
|
||||||
img2_path="dataset/img2.jpg",
|
|
||||||
model_name="DeepId",
|
|
||||||
)
|
|
||||||
|
|
||||||
if restore:
|
|
||||||
os.remove(weights_file)
|
|
||||||
os.rename(backup_file, weights_file)
|
|
||||||
|
|
||||||
logger.info("✅ test verify for broken weight file is done")
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user