From 6143ed9bb4a9773a32cbe7c7d8d740ec3a9c7a9f Mon Sep 17 00:00:00 2001 From: galthran-wq Date: Fri, 21 Feb 2025 17:35:25 +0000 Subject: [PATCH] clearify test batch extract --- tests/test_extract_faces.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_extract_faces.py b/tests/test_extract_faces.py index 68939f3..c61f1b5 100644 --- a/tests/test_extract_faces.py +++ b/tests/test_extract_faces.py @@ -104,6 +104,7 @@ def test_batch_extract_faces(detector_backend): "dataset/img11.jpg", "dataset/couple.jpg" ] + expected_num_faces = [1, 1, 1, 2] # Extract faces one by one imgs_objs_individual = [ @@ -121,15 +122,12 @@ def test_batch_extract_faces(detector_backend): align=True, ) - assert ( - len(imgs_objs_batch) == 4 and - all(isinstance(obj, list) for obj in imgs_objs_batch) - ) - assert all( - len(imgs_objs_batch[i]) == 1 - for i in range(len(imgs_objs_batch[:-1])) - ) - assert len(imgs_objs_batch[-1]) == 2 + # Check that the batch extraction returned the expected number of face lists + assert len(imgs_objs_batch) == len(img_paths) + + # Check that each face list has the expected number of faces + for i, expected_faces in enumerate(expected_num_faces): + assert len(imgs_objs_batch[i]) == expected_faces for img_objs_individual, img_objs_batch in zip(imgs_objs_individual, imgs_objs_batch): assert len(img_objs_batch) == len(img_objs_individual), (