detect faces return list of lists on batched inputs

This commit is contained in:
galthran-wq 2025-02-13 13:59:16 +00:00
parent c4b4b4a736
commit 60bee4e1a9
3 changed files with 18 additions and 5 deletions

View File

@ -519,7 +519,7 @@ def extract_faces(
color_face: str = "rgb",
normalize_face: bool = True,
anti_spoofing: bool = False,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Extract faces from a given image or sequence of images.
@ -551,7 +551,8 @@ def extract_faces(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]):
A list or a list of lists of dictionaries, where each dictionary contains:
- "face" (np.ndarray): The detected face as a NumPy array.

View File

@ -29,7 +29,7 @@ def extract_faces(
normalize_face: bool = True,
anti_spoofing: bool = False,
max_faces: Optional[int] = None,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Extract faces from a given image or list of images
@ -62,7 +62,8 @@ def extract_faces(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]):
A list or list of lists of dictionaries, where each dictionary contains:
- "face" (np.ndarray): The detected face as a NumPy array in RGB format.
@ -131,6 +132,7 @@ def extract_faces(
base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
img_resp_objs = []
for face_obj in face_objs:
current_img = face_obj.img
current_region = face_obj.facial_area
@ -195,8 +197,12 @@ def extract_faces(
resp_obj["is_real"] = is_real
resp_obj["antispoof_score"] = antispoof_score
all_resp_objs.append(resp_obj)
img_resp_objs.append(resp_obj)
all_resp_objs.append(img_resp_objs)
if len(all_resp_objs) == 1:
return all_resp_objs[0]
return all_resp_objs

View File

@ -106,6 +106,12 @@ def test_batch_extract_faces(detector_backend):
align=True,
)
assert (
len(img_objs_batch) == 3 and
all(isinstance(obj, list) and len(obj) == 1 for obj in img_objs_batch)
)
img_objs_batch = [obj for sublist in img_objs_batch for obj in sublist]
assert len(img_objs_batch) == len(img_objs_individual)
for img_obj_individual, img_obj_batch in zip(img_objs_individual, img_objs_batch):