Merge pull request #1329 from serengil/feat-task-3108-unit-test-for-broken-weight-file

unit test for broken weight file
This commit is contained in:
Sefik Ilkin Serengil 2024-08-31 18:13:48 +01:00 committed by GitHub
commit 2feb703f96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 0 deletions

View File

@ -1,3 +1,6 @@
# built-in dependencies
import hashlib
# 3rd party dependencies
import tensorflow as tf
@ -44,3 +47,19 @@ def validate_for_keras3():
"tf-keras package. Please run `pip install tf-keras` "
"or downgrade your tensorflow."
) from err
def find_file_hash(file_path: str, hash_algorithm: str = "sha256") -> str:
"""
Find the hash of a given file with its content
Args:
file_path (str): exact path of a given file
hash_algorithm (str): hash algorithm
Returns:
hash (str)
"""
hash_func = hashlib.new(hash_algorithm)
with open(file_path, "rb") as f:
while chunk := f.read(8192):
hash_func.update(chunk)
return hash_func.hexdigest()

46
tests/test_commons.py Normal file
View File

@ -0,0 +1,46 @@
# built-in dependencies
import os
import pytest
# project dependencies
from deepface.commons import folder_utils, weight_utils, package_utils
from deepface.commons.logger import Logger
logger = Logger()
tf_version = package_utils.get_tf_major_version()
if tf_version == 1:
from keras.models import Sequential
from keras.layers import (
Dropout,
Dense,
)
else:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
Dropout,
Dense,
)
def test_loading_broken_weights():
home = folder_utils.get_deepface_home()
weight_file = os.path.join(home, ".deepface/weights/vgg_face_weights.h5")
# construct a dummy model
model = Sequential()
# Add layers to the model
model.add(
Dense(units=64, activation="relu", input_shape=(100,))
) # Input layer with 100 features
model.add(Dropout(0.5)) # Dropout layer to prevent overfitting
model.add(Dense(units=32, activation="relu")) # Hidden layer
model.add(Dense(units=10, activation="softmax")) # Output layer with 10 classes
# vgg's weights cannot be loaded to this model
with pytest.raises(ValueError, match="Exception while loading pre-trained weights from"):
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
logger.info("✅ test loading broken weight file is done")

View File

@ -186,3 +186,5 @@ def test_verify_for_nested_embeddings():
match="When passing img1_path as a list, ensure that all its items are of type float",
):
_ = DeepFace.verify(img1_path=img1_embeddings, img2_path=img2_path)
logger.info("✅ test verify for nested embeddings is done")