mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
batched represent
This commit is contained in:
parent
bb134b25d2
commit
c60152e9a5
@ -2,6 +2,7 @@
|
||||
import io
|
||||
import cv2
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
# project dependencies
|
||||
from deepface import DeepFace
|
||||
@ -81,3 +82,49 @@ def test_max_faces():
|
||||
max_faces = 1
|
||||
results = DeepFace.represent(img_path="dataset/couple.jpg", max_faces=max_faces)
|
||||
assert len(results) == max_faces
|
||||
|
||||
|
||||
def test_batched_represent():
|
||||
img_paths = [
|
||||
"dataset/img1.jpg",
|
||||
"dataset/img2.jpg",
|
||||
"dataset/img3.jpg",
|
||||
"dataset/img4.jpg",
|
||||
"dataset/img5.jpg",
|
||||
]
|
||||
|
||||
def _test_for_model(model_name: str):
|
||||
embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
|
||||
assert len(embedding_objs) == len(img_paths)
|
||||
if model_name == "VGG-Face":
|
||||
for embedding_obj in embedding_objs:
|
||||
embedding = embedding_obj["embedding"]
|
||||
logger.debug(f"Function returned {len(embedding)} dimensional vector")
|
||||
assert len(embedding) == 4096
|
||||
embedding_objs_one_by_one = [
|
||||
embedding_obj
|
||||
for img_path in img_paths
|
||||
for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
|
||||
]
|
||||
for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
|
||||
assert np.allclose(
|
||||
embedding_obj_one_by_one["embedding"],
|
||||
embedding_obj["embedding"],
|
||||
rtol=1e-2,
|
||||
atol=1e-2
|
||||
)
|
||||
|
||||
for model_name in [
|
||||
"VGG-Face",
|
||||
"Facenet",
|
||||
"Facenet512",
|
||||
"OpenFace",
|
||||
# "DeepFace",
|
||||
"DeepID",
|
||||
# "Dlib",
|
||||
"ArcFace",
|
||||
"SFace",
|
||||
"GhostFaceNet"
|
||||
]:
|
||||
_test_for_model(model_name)
|
||||
logger.info("✅ test batch represent function done")
|
||||
|
Loading…
x
Reference in New Issue
Block a user