From 60bee4e1a98a40829f8637a998a0d324c4b46650 Mon Sep 17 00:00:00 2001 From: galthran-wq Date: Thu, 13 Feb 2025 13:59:16 +0000 Subject: [PATCH] detect faces return list of lists on batched inputs --- deepface/DeepFace.py | 5 +++-- deepface/modules/detection.py | 12 +++++++++--- tests/test_extract_faces.py | 6 ++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index eeacbe7..f4253c4 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -519,7 +519,7 @@ def extract_faces( color_face: str = "rgb", normalize_face: bool = True, anti_spoofing: bool = False, -) -> List[Dict[str, Any]]: +) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]: """ Extract faces from a given image or sequence of images. @@ -551,7 +551,8 @@ def extract_faces( anti_spoofing (boolean): Flag to enable anti spoofing (default is False). Returns: - results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains: + results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]): + A list or a list of lists of dictionaries, where each dictionary contains: - "face" (np.ndarray): The detected face as a NumPy array. diff --git a/deepface/modules/detection.py b/deepface/modules/detection.py index 5d974df..e3b679d 100644 --- a/deepface/modules/detection.py +++ b/deepface/modules/detection.py @@ -29,7 +29,7 @@ def extract_faces( normalize_face: bool = True, anti_spoofing: bool = False, max_faces: Optional[int] = None, -) -> List[Dict[str, Any]]: +) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]: """ Extract faces from a given image or list of images @@ -62,7 +62,8 @@ def extract_faces( anti_spoofing (boolean): Flag to enable anti spoofing (default is False). Returns: - results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains: + results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]): + A list or list of lists of dictionaries, where each dictionary contains: - "face" (np.ndarray): The detected face as a NumPy array in RGB format. @@ -131,6 +132,7 @@ def extract_faces( base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0) face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)] + img_resp_objs = [] for face_obj in face_objs: current_img = face_obj.img current_region = face_obj.facial_area @@ -195,8 +197,12 @@ def extract_faces( resp_obj["is_real"] = is_real resp_obj["antispoof_score"] = antispoof_score - all_resp_objs.append(resp_obj) + img_resp_objs.append(resp_obj) + all_resp_objs.append(img_resp_objs) + + if len(all_resp_objs) == 1: + return all_resp_objs[0] return all_resp_objs diff --git a/tests/test_extract_faces.py b/tests/test_extract_faces.py index 24f404e..2c34d94 100644 --- a/tests/test_extract_faces.py +++ b/tests/test_extract_faces.py @@ -106,6 +106,12 @@ def test_batch_extract_faces(detector_backend): align=True, ) + assert ( + len(img_objs_batch) == 3 and + all(isinstance(obj, list) and len(obj) == 1 for obj in img_objs_batch) + ) + img_objs_batch = [obj for sublist in img_objs_batch for obj in sublist] + assert len(img_objs_batch) == len(img_objs_individual) for img_obj_individual, img_obj_batch in zip(img_objs_individual, img_objs_batch):