diff --git a/tests/test_extract_faces.py b/tests/test_extract_faces.py index 68939f3..c61f1b5 100644 --- a/tests/test_extract_faces.py +++ b/tests/test_extract_faces.py @@ -104,6 +104,7 @@ def test_batch_extract_faces(detector_backend): "dataset/img11.jpg", "dataset/couple.jpg" ] + expected_num_faces = [1, 1, 1, 2] # Extract faces one by one imgs_objs_individual = [ @@ -121,15 +122,12 @@ def test_batch_extract_faces(detector_backend): align=True, ) - assert ( - len(imgs_objs_batch) == 4 and - all(isinstance(obj, list) for obj in imgs_objs_batch) - ) - assert all( - len(imgs_objs_batch[i]) == 1 - for i in range(len(imgs_objs_batch[:-1])) - ) - assert len(imgs_objs_batch[-1]) == 2 + # Check that the batch extraction returned the expected number of face lists + assert len(imgs_objs_batch) == len(img_paths) + + # Check that each face list has the expected number of faces + for i, expected_faces in enumerate(expected_num_faces): + assert len(imgs_objs_batch[i]) == expected_faces for img_objs_individual, img_objs_batch in zip(imgs_objs_individual, imgs_objs_batch): assert len(img_objs_batch) == len(img_objs_individual), (