enforce broken weight file test at the end of all tests

This commit is contained in:
Sefik Ilkin Serengil 2024-08-31 17:31:00 +01:00
parent ed8a6404d9
commit 074e81ba50
2 changed files with 35 additions and 3 deletions

View File

@ -1,3 +1,6 @@
# built-in dependencies
import hashlib
# 3rd party dependencies
import tensorflow as tf
@ -44,3 +47,19 @@ def validate_for_keras3():
"tf-keras package. Please run `pip install tf-keras` "
"or downgrade your tensorflow."
) from err
def find_file_hash(file_path: str, hash_algorithm: str = "sha256") -> str:
"""
Find the hash of a given file with its content
Args:
file_path (str): exact path of a given file
hash_algorithm (str): hash algorithm
Returns:
hash (str)
"""
hash_func = hashlib.new(hash_algorithm)
with open(file_path, "rb") as f:
while chunk := f.read(8192):
hash_func.update(chunk)
return hash_func.hexdigest()

View File

@ -7,7 +7,7 @@ import cv2
# project dependencies
from deepface import DeepFace
from deepface.commons import folder_utils
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger
logger = Logger()
@ -75,6 +75,9 @@ def test_different_facial_recognition_models():
logger.info(f"✅ facial recognition models test passed with {coverage_score}")
# test_different_facial_recognition_models takes long time. run broken weight test after it.
verify_for_broken_weights()
def test_different_face_detectors():
for detector in detectors:
@ -194,13 +197,23 @@ def test_verify_for_nested_embeddings():
logger.info("✅ test verify for nested embeddings is done")
def test_verify_for_broken_weights():
def verify_for_broken_weights():
home = folder_utils.get_deepface_home()
weights_file = os.path.join(home, ".deepface/weights/vgg_face_weights.h5")
backup_file = os.path.join(home, ".deepface/weights/vgg_face_weights_backup.h5")
assert os.path.exists(weights_file) is True
# confirm that weight file is available
if os.path.exists(weights_file) is False:
_ = DeepFace.verify(
img1_path="dataset/img1.jpg",
img2_path="dataset/img2.jpg",
model_name="VGG-Face",
)
# confirm that weith file is not broken
weights_file_hash = package_utils.find_file_hash(weights_file)
assert "759266b9614d0fd5d65b97bf716818b746cc77ab5944c7bffc937c6ba9455d8c" == weights_file_hash
# backup original weight file
os.rename(weights_file, backup_file)