From 9e12c92d8a00623019d8587d236acdfa2f7527d2 Mon Sep 17 00:00:00 2001 From: galthran-wq Date: Tue, 11 Feb 2025 17:35:53 +0000 Subject: [PATCH] refactor test --- tests/test_represent.py | 71 +++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/tests/test_represent.py b/tests/test_represent.py index f09834e..3ac65fa 100644 --- a/tests/test_represent.py +++ b/tests/test_represent.py @@ -3,6 +3,7 @@ import io import cv2 import pytest import numpy as np +import pytest # project dependencies from deepface import DeepFace @@ -84,7 +85,19 @@ def test_max_faces(): assert len(results) == max_faces -def test_batched_represent(): +@pytest.mark.parametrize("model_name", [ + "VGG-Face", + "Facenet", + "Facenet512", + "OpenFace", + "DeepFace", + "DeepID", + "Dlib", + "ArcFace", + "SFace", + "GhostFaceNet" +]) +def test_batched_represent(model_name): img_paths = [ "dataset/img1.jpg", "dataset/img2.jpg", @@ -93,38 +106,26 @@ def test_batched_represent(): "dataset/img5.jpg", ] - def _test_for_model(model_name: str): - embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name) - assert len(embedding_objs) == len(img_paths) - if model_name == "VGG-Face": - for embedding_obj in embedding_objs: - embedding = embedding_obj["embedding"] - logger.debug(f"Function returned {len(embedding)} dimensional vector") - assert len(embedding) == 4096 - embedding_objs_one_by_one = [ - embedding_obj - for img_path in img_paths - for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name) - ] - for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs): - assert np.allclose( - embedding_obj_one_by_one["embedding"], - embedding_obj["embedding"], - rtol=1e-2, - atol=1e-2 - ) + embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name) + assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}" - for model_name in [ - "VGG-Face", - "Facenet", - "Facenet512", - "OpenFace", - # "DeepFace", - "DeepID", - # "Dlib", - "ArcFace", - "SFace", - "GhostFaceNet" - ]: - _test_for_model(model_name) - logger.info("✅ test batch represent function done") + if model_name == "VGG-Face": + for embedding_obj in embedding_objs: + embedding = embedding_obj["embedding"] + logger.debug(f"Function returned {len(embedding)} dimensional vector") + assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}" + + embedding_objs_one_by_one = [ + embedding_obj + for img_path in img_paths + for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name) + ] + for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs): + assert np.allclose( + embedding_obj_one_by_one["embedding"], + embedding_obj["embedding"], + rtol=1e-2, + atol=1e-2 + ), "Embeddings do not match within tolerance" + + logger.info(f"✅ test batch represent function for model {model_name} done")