more shape checks

This commit is contained in:
galthran-wq 2025-02-21 17:40:28 +00:00
parent 6143ed9bb4
commit c46d886a67

View File

@ -23,6 +23,10 @@ def test_different_detectors():
for detector in detectors:
img_objs = DeepFace.extract_faces(img_path=img_path, detector_backend=detector)
# Check return type for non-batch input
assert isinstance(img_objs, list) and all(isinstance(obj, dict) for obj in img_objs)
for img_obj in img_objs:
assert "face" in img_obj.keys()
assert "facial_area" in img_obj.keys()
@ -114,6 +118,15 @@ def test_batch_extract_faces(detector_backend):
align=True,
) for img_path in img_paths
]
# Check that individual extraction returns a list of faces
for img_objs_individual in imgs_objs_individual:
assert isinstance(img_objs_individual, list)
assert all(isinstance(face, dict) for face in img_objs_individual)
# Check that the individual extraction results match the expected number of faces
for img_objs_individual, expected_faces in zip(imgs_objs_individual, expected_num_faces):
assert len(img_objs_individual) == expected_faces
# Extract faces in batch
imgs_objs_batch = DeepFace.extract_faces(
@ -129,6 +142,7 @@ def test_batch_extract_faces(detector_backend):
for i, expected_faces in enumerate(expected_num_faces):
assert len(imgs_objs_batch[i]) == expected_faces
# Check that the individual extraction results match the batch extraction results
for img_objs_individual, img_objs_batch in zip(imgs_objs_individual, imgs_objs_batch):
assert len(img_objs_batch) == len(img_objs_individual), (
"Batch and individual extraction results should have the same number of detected faces"
@ -190,6 +204,18 @@ def test_batch_extract_faces_with_nparray(detector_backend):
align=True,
enforce_detection=False,
)
# Check return type for batch input
assert (
isinstance(imgs_objs_batch, list) and
all(
isinstance(obj, list) and
all(isinstance(face, dict) for face in obj)
for obj in imgs_objs_batch
)
)
# Check that the batch extraction returned the expected number of face lists
assert len(imgs_objs_batch) == 4
for img_objs_batch, expected_num_faces in zip(imgs_objs_batch, expected_num_faces):
assert len(img_objs_batch) == expected_num_faces