refactor test

This commit is contained in:
galthran-wq 2025-02-11 17:35:53 +00:00
parent 8becc97512
commit 9e12c92d8a

View File

@ -3,6 +3,7 @@ import io
import cv2 import cv2
import pytest import pytest
import numpy as np import numpy as np
import pytest
# project dependencies # project dependencies
from deepface import DeepFace from deepface import DeepFace
@ -84,7 +85,19 @@ def test_max_faces():
assert len(results) == max_faces assert len(results) == max_faces
def test_batched_represent(): @pytest.mark.parametrize("model_name", [
"VGG-Face",
"Facenet",
"Facenet512",
"OpenFace",
"DeepFace",
"DeepID",
"Dlib",
"ArcFace",
"SFace",
"GhostFaceNet"
])
def test_batched_represent(model_name):
img_paths = [ img_paths = [
"dataset/img1.jpg", "dataset/img1.jpg",
"dataset/img2.jpg", "dataset/img2.jpg",
@ -93,38 +106,26 @@ def test_batched_represent():
"dataset/img5.jpg", "dataset/img5.jpg",
] ]
def _test_for_model(model_name: str): embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name) assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
assert len(embedding_objs) == len(img_paths)
if model_name == "VGG-Face":
for embedding_obj in embedding_objs:
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
assert len(embedding) == 4096
embedding_objs_one_by_one = [
embedding_obj
for img_path in img_paths
for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
]
for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
assert np.allclose(
embedding_obj_one_by_one["embedding"],
embedding_obj["embedding"],
rtol=1e-2,
atol=1e-2
)
for model_name in [ if model_name == "VGG-Face":
"VGG-Face", for embedding_obj in embedding_objs:
"Facenet", embedding = embedding_obj["embedding"]
"Facenet512", logger.debug(f"Function returned {len(embedding)} dimensional vector")
"OpenFace", assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
# "DeepFace",
"DeepID", embedding_objs_one_by_one = [
# "Dlib", embedding_obj
"ArcFace", for img_path in img_paths
"SFace", for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
"GhostFaceNet" ]
]: for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
_test_for_model(model_name) assert np.allclose(
logger.info("✅ test batch represent function done") embedding_obj_one_by_one["embedding"],
embedding_obj["embedding"],
rtol=1e-2,
atol=1e-2
), "Embeddings do not match within tolerance"
logger.info(f"✅ test batch represent function for model {model_name} done")