diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py
index 06a7bda..9c12a9b 100644
--- a/deepface/DeepFace.py
+++ b/deepface/DeepFace.py
@@ -521,7 +521,7 @@ def stream(


 def extract_faces(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
     detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
@@ -530,14 +530,14 @@ def extract_faces(
     color_face: str = "rgb",
     normalize_face: bool = True,
     anti_spoofing: bool = False,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
-    Extract faces from a given image
+    Extract faces from a given image or sequence of images.

     Args:
-        img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
-            as a string, numpy array (BGR), a file object that supports at least `.read` and is
-            opened in binary mode, or base64 encoded images.
+        img_path (Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]]):
+            Path(s) to the image(s). Accepts a string path, a numpy array (BGR), a file object
+            that supports at least `.read` and is opened in binary mode, or base64 encoded images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
             'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
@@ -562,7 +562,8 @@ def extract_faces(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
+        results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]):
+            A list or a list of lists of dictionaries, where each dictionary contains:

         - "face" (np.ndarray): The detected face as a NumPy array.
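
For reference, a minimal usage sketch of the reworked entry point (not part of the patch), assuming the repository's bundled `dataset/` images are available: a single input keeps the old return shape, a list input returns one result list per image.

    from deepface import DeepFace

    # single image -> List[Dict[str, Any]], exactly as before
    faces = DeepFace.extract_faces(img_path="dataset/img11.jpg")

    # list of images -> List[List[Dict[str, Any]]], one inner list per input image
    batched = DeepFace.extract_faces(img_path=["dataset/img11.jpg", "dataset/couple.jpg"])
    for image_faces in batched:
        for face in image_faces:
            print(face["facial_area"], face["confidence"])
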
diff --git a/deepface/models/Detector.py b/deepface/models/Detector.py
index 004f0d3..730c432 100644
--- a/deepface/models/Detector.py
+++ b/deepface/models/Detector.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Union
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 import numpy as np
@@ -9,15 +9,20 @@ import numpy as np
 # pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
 class Detector(ABC):
     @abstractmethod
-    def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List["FacialAreaRegion"], List[List["FacialAreaRegion"]]]:
         """
-        Interface for detect and align face
+        Interface to detect and align faces in a single image or a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]):
+                A list or a list of lists of FacialAreaRegion objects
                 where each object contains:

             - facial_area (FacialAreaRegion): The facial area region represented
@@ -28,6 +33,7 @@ class Detector(ABC):
         pass


+# pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
 @dataclass
 class FacialAreaRegion:
     """
diff --git a/deepface/models/face_detection/CenterFace.py b/deepface/models/face_detection/CenterFace.py
index b8fdf6b..523f2bf 100644
--- a/deepface/models/face_detection/CenterFace.py
+++ b/deepface/models/face_detection/CenterFace.py
@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -34,12 +34,35 @@ class CenterFaceClient(Detector):

         return CenterFace(weight_path=weights_path)

-    def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with CenterFace

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
+        Args:
+            single_img (np.ndarray): pre-loaded image as numpy array

         Returns:
             results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
@@ -53,7 +76,7 @@ class CenterFaceClient(Detector):
         #     img, img.shape[0], img.shape[1], threshold=threshold
         # )
         detections, landmarks = self.build_model().forward(
-            img, img.shape[0], img.shape[1], threshold=threshold
+            single_img, single_img.shape[0], single_img.shape[1], threshold=threshold
         )

         for i, detection in enumerate(detections):
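
A sketch of what the updated abstract contract asks of a concrete detector: accept one image or a list of images, and mirror that shape in the return value. `DummyDetector` and its whole-frame box are hypothetical, not part of the patch.

    from typing import List, Union

    import numpy as np

    from deepface.models.Detector import Detector, FacialAreaRegion

    class DummyDetector(Detector):
        def detect_faces(
            self,
            img: Union[np.ndarray, List[np.ndarray]],
        ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
            is_batched_input = isinstance(img, list)
            if not is_batched_input:
                img = [img]
            # one FacialAreaRegion list per image; here simply the full frame
            results = [
                [FacialAreaRegion(x=0, y=0, w=i.shape[1], h=i.shape[0], confidence=0.0)]
                for i in img
            ]
            return results if is_batched_input else results[0]
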
diff --git a/deepface/models/face_detection/Dlib.py b/deepface/models/face_detection/Dlib.py
index 26bce84..254a32b 100644
--- a/deepface/models/face_detection/Dlib.py
+++ b/deepface/models/face_detection/Dlib.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -47,10 +47,33 @@ class DlibClient(Detector):
         detector["sp"] = sp
         return detector

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with dlib

+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
         Args:
             img (np.ndarray): pre-loaded image as numpy array
diff --git a/deepface/models/face_detection/FastMtCnn.py b/deepface/models/face_detection/FastMtCnn.py
index 5259036..7d5ccf5 100644
--- a/deepface/models/face_detection/FastMtCnn.py
+++ b/deepface/models/face_detection/FastMtCnn.py
@@ -17,10 +17,33 @@ class FastMtCnnClient(Detector):
     def __init__(self):
         self.model = self.build_model()

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with mtcnn

+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
         Args:
             img (np.ndarray): pre-loaded image as numpy array
diff --git a/deepface/models/face_detection/MediaPipe.py b/deepface/models/face_detection/MediaPipe.py
index 48bc2f8..9fcdbbc 100644
--- a/deepface/models/face_detection/MediaPipe.py
+++ b/deepface/models/face_detection/MediaPipe.py
@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import Any, List, Union

 # 3rd party dependencies
 import numpy as np
@@ -43,10 +43,33 @@ class MediaPipeClient(Detector):
         )
         return face_detection

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with mediapipe

+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
         Args:
             img (np.ndarray): pre-loaded image as numpy array
diff --git a/deepface/models/face_detection/MtCnn.py b/deepface/models/face_detection/MtCnn.py
index 014e4a5..3806e99 100644
--- a/deepface/models/face_detection/MtCnn.py
+++ b/deepface/models/face_detection/MtCnn.py
@@ -1,5 +1,6 @@
 # built-in dependencies
-from typing import List
+import logging
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -8,6 +9,8 @@ from mtcnn import MTCNN
 # project dependencies
 from deepface.models.Detector import Detector, FacialAreaRegion

+logger = logging.getLogger(__name__)
+
 # pylint: disable=too-few-public-methods
 class MtCnnClient(Detector):
     """
@@ -16,45 +19,71 @@ class MtCnnClient(Detector):

     def __init__(self):
         self.model = MTCNN()
+        self.supports_batch_detection = self._supports_batch_detection()

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
-        Detect and align face with mtcnn
+        Detect and align faces with mtcnn for a single image or a list of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list of FacialAreaRegion objects for a single image
+                or a list of lists of FacialAreaRegion objects for each image
         """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+
         resp = []

         # mtcnn expects RGB but OpenCV read BGR
         # img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        img_rgb = img[:, :, ::-1]
-        detections = self.model.detect_faces(img_rgb)
+        img_rgb = [single_img[:, :, ::-1] for single_img in img]
+        if self.supports_batch_detection:
+            detections = self.model.detect_faces(img_rgb)
+        else:
+            detections = [self.model.detect_faces(single_img) for single_img in img_rgb]

-        if detections is not None and len(detections) > 0:
+        for image_detections in detections:
+            image_resp = []
+            if image_detections is not None and len(image_detections) > 0:
+                for current_detection in image_detections:
+                    x, y, w, h = current_detection["box"]
+                    confidence = current_detection["confidence"]
+                    # mtcnn detector assigns left eye with respect to the observer
+                    # but we are setting it with respect to the person itself
+                    left_eye = current_detection["keypoints"]["right_eye"]
+                    right_eye = current_detection["keypoints"]["left_eye"]

-            for current_detection in detections:
-                x, y, w, h = current_detection["box"]
-                confidence = current_detection["confidence"]
-                # mtcnn detector assigns left eye with respect to the observer
-                # but we are setting it with respect to the person itself
-                left_eye = current_detection["keypoints"]["right_eye"]
-                right_eye = current_detection["keypoints"]["left_eye"]
+                    facial_area = FacialAreaRegion(
+                        x=x,
+                        y=y,
+                        w=w,
+                        h=h,
+                        left_eye=left_eye,
+                        right_eye=right_eye,
+                        confidence=confidence,
+                    )

-                facial_area = FacialAreaRegion(
-                    x=x,
-                    y=y,
-                    w=w,
-                    h=h,
-                    left_eye=left_eye,
-                    right_eye=right_eye,
-                    confidence=confidence,
-                )
+                    image_resp.append(facial_area)

-                resp.append(facial_area)
+            resp.append(image_resp)

+        if not is_batched_input:
+            return resp[0]
         return resp
+
+    def _supports_batch_detection(self) -> bool:
+        logger.warning(
+            "Batch detection is disabled for mtcnn by default "
+            "since the results are not consistent with single image detection."
+        )
+        return False
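
Since the guard is resolved through a method, a subclass can opt back in to true batch inference. Whether `MTCNN.detect_faces` actually accepts a list of images depends on the installed mtcnn version, so treat this hypothetical override as an assumption to verify, not part of the patch.

    from deepface.models.face_detection.MtCnn import MtCnnClient

    class BatchMtCnnClient(MtCnnClient):
        def _supports_batch_detection(self) -> bool:
            # opt in: pass the whole RGB list to MTCNN in one call
            return True
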
diff --git a/deepface/models/face_detection/OpenCv.py b/deepface/models/face_detection/OpenCv.py
index 4abb6da..667e8fb 100644
--- a/deepface/models/face_detection/OpenCv.py
+++ b/deepface/models/face_detection/OpenCv.py
@@ -1,6 +1,7 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import Any, List, Union
+import logging

 # 3rd party dependencies
 import cv2
@@ -9,6 +10,7 @@ import numpy as np
 #project dependencies
 from deepface.models.Detector import Detector, FacialAreaRegion

+logger = logging.getLogger(__name__)

 class OpenCvClient(Detector):
     """
@@ -17,6 +19,7 @@ class OpenCvClient(Detector):

     def __init__(self):
         self.model = self.build_model()
+        self.supports_batch_detection = self._supports_batch_detection()

     def build_model(self):
         """
@@ -29,55 +32,72 @@ class OpenCvClient(Detector):
         detector["eye_detector"] = self.__build_cascade("haarcascade_eye")
         return detector

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def _supports_batch_detection(self) -> bool:
+        logger.warning(
+            "Batch detection is disabled for opencv by default "
+            "since the results are not consistent with single image detection."
+        )
+        return False
+
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with opencv

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
         """
-        resp = []
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            imgs = [img]
+        elif self.supports_batch_detection:
+            imgs = img
+        else:
+            return [self.detect_faces(single_img) for single_img in img]

-        detected_face = None
+        batch_results = []

-        faces = []
-        try:
-            # faces = detector["face_detector"].detectMultiScale(img, 1.3, 5)
-
-            # note that, by design, opencv's haarcascade scores are >0 but not capped at 1
-            faces, _, scores = self.model["face_detector"].detectMultiScale3(
-                img, 1.1, 10, outputRejectLevels=True
-            )
-        except:
-            pass
-
-        if len(faces) > 0:
-            for (x, y, w, h), confidence in zip(faces, scores):
-                detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
-                left_eye, right_eye = self.find_eyes(img=detected_face)
-
-                # eyes found in the detected face instead image itself
-                # detected face's coordinates should be added
-                if left_eye is not None:
-                    left_eye = (int(x + left_eye[0]), int(y + left_eye[1]))
-                if right_eye is not None:
-                    right_eye = (int(x + right_eye[0]), int(y + right_eye[1]))
-
-                facial_area = FacialAreaRegion(
-                    x=x,
-                    y=y,
-                    w=w,
-                    h=h,
-                    left_eye=left_eye,
-                    right_eye=right_eye,
-                    confidence=(100 - confidence) / 100,
+        for single_img in imgs:
+            resp = []
+            detected_face = None
+            faces = []
+            try:
+                # note that, by design, opencv's haarcascade scores are >0 but not capped at 1
+                faces, _, scores = self.model["face_detector"].detectMultiScale3(
+                    single_img, 1.1, 10, outputRejectLevels=True
                 )
-                resp.append(facial_area)
+            except:
+                pass

-        return resp
+            if len(faces) > 0:
+                for (x, y, w, h), confidence in zip(faces, scores):
+                    detected_face = single_img[int(y):int(y + h), int(x):int(x + w)]
+                    left_eye, right_eye = self.find_eyes(img=detected_face)
+
+                    # eyes found in the detected face instead of the image itself,
+                    # so the detected face's coordinates should be added
+                    if left_eye is not None:
+                        left_eye = (int(x + left_eye[0]), int(y + left_eye[1]))
+                    if right_eye is not None:
+                        right_eye = (int(x + right_eye[0]), int(y + right_eye[1]))
+
+                    facial_area = FacialAreaRegion(
+                        x=x,
+                        y=y,
+                        w=w,
+                        h=h,
+                        left_eye=left_eye,
+                        right_eye=right_eye,
+                        confidence=(100 - confidence) / 100,
+                    )
+                    resp.append(facial_area)
+
+            batch_results.append(resp)
+
+        return batch_results if is_batched_input else batch_results[0]

     def find_eyes(self, img: np.ndarray) -> tuple:
         """
diff --git a/deepface/models/face_detection/RetinaFace.py b/deepface/models/face_detection/RetinaFace.py
index c687322..2722e01 100644
--- a/deepface/models/face_detection/RetinaFace.py
+++ b/deepface/models/face_detection/RetinaFace.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -13,64 +13,75 @@ class RetinaFaceClient(Detector):
     def __init__(self):
         self.model = rf.build_model()

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
-        Detect and align face with retinaface
+        Detect and align faces with retinaface in a single image or a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
         """
-        resp = []
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            imgs = [img]
+        else:
+            imgs = img

-        obj = rf.detect_faces(img, model=self.model, threshold=0.9)
+        batch_results = []

-        if not isinstance(obj, dict):
-            return resp
+        for single_img in imgs:
+            resp = []
+            obj = rf.detect_faces(single_img, model=self.model, threshold=0.9)

-        for face_idx in obj.keys():
-            identity = obj[face_idx]
-            detection = identity["facial_area"]
+            if isinstance(obj, dict):
+                for face_idx in obj.keys():
+                    identity = obj[face_idx]
+                    detection = identity["facial_area"]

-            y = detection[1]
-            h = detection[3] - y
-            x = detection[0]
-            w = detection[2] - x
+                    y = detection[1]
+                    h = detection[3] - y
+                    x = detection[0]
+                    w = detection[2] - x

-            # retinaface sets left and right eyes with respect to the person
-            left_eye = identity["landmarks"]["left_eye"]
-            right_eye = identity["landmarks"]["right_eye"]
-            nose = identity["landmarks"].get("nose")
-            mouth_right = identity["landmarks"].get("mouth_right")
-            mouth_left = identity["landmarks"].get("mouth_left")
+                    # retinaface sets left and right eyes with respect to the person;
+                    # landmarks come as lists of float, so cast them to tuples of int
+                    left_eye = tuple(int(i) for i in identity["landmarks"]["left_eye"])
+                    right_eye = tuple(int(i) for i in identity["landmarks"]["right_eye"])
+                    nose = identity["landmarks"].get("nose")
+                    mouth_right = identity["landmarks"].get("mouth_right")
+                    mouth_left = identity["landmarks"].get("mouth_left")

-            # eyes are list of float, need to cast them tuple of int
-            left_eye = tuple(int(i) for i in left_eye)
-            right_eye = tuple(int(i) for i in right_eye)
-            if nose is not None:
-                nose = tuple(int(i) for i in nose)
-            if mouth_right is not None:
-                mouth_right = tuple(int(i) for i in mouth_right)
-            if mouth_left is not None:
-                mouth_left = tuple(int(i) for i in mouth_left)
+                    if nose is not None:
+                        nose = tuple(int(i) for i in nose)
+                    if mouth_right is not None:
+                        mouth_right = tuple(int(i) for i in mouth_right)
+                    if mouth_left is not None:
+                        mouth_left = tuple(int(i) for i in mouth_left)

-            confidence = identity["score"]
+                    confidence = identity["score"]

-            facial_area = FacialAreaRegion(
-                x=x,
-                y=y,
-                w=w,
-                h=h,
-                left_eye=left_eye,
-                right_eye=right_eye,
-                confidence=confidence,
-                nose=nose,
-                mouth_left=mouth_left,
-                mouth_right=mouth_right,
-            )
+                    facial_area = FacialAreaRegion(
+                        x=x,
+                        y=y,
+                        w=w,
+                        h=h,
+                        left_eye=left_eye,
+                        right_eye=right_eye,
+                        confidence=confidence,
+                        nose=nose,
+                        mouth_left=mouth_left,
+                        mouth_right=mouth_right,
+                    )

-            resp.append(facial_area)
+                    resp.append(facial_area)

-        return resp
+            batch_results.append(resp)
+
+        if not is_batched_input:
+            return batch_results[0]
+        return batch_results
diff --git a/deepface/models/face_detection/Ssd.py b/deepface/models/face_detection/Ssd.py
index 449144f..a620f96 100644
--- a/deepface/models/face_detection/Ssd.py
+++ b/deepface/models/face_detection/Ssd.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union
 from enum import IntEnum

 # 3rd party dependencies
@@ -54,28 +54,48 @@ class SsdClient(Detector):

         return {"face_detector": face_detector, "opencv_module": OpenCv.OpenCvClient()}

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
-        Detect and align face with ssd
+        Detect and align faces with ssd in a single image or a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
+        Args:
+            single_img (np.ndarray): Pre-loaded image as numpy array

         Returns:
             results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
         """
-        # Because cv2.dnn.blobFromImage expects CV_8U (8-bit unsigned integer) values
-        if img.dtype != np.uint8:
-            img = img.astype(np.uint8)
+        # Because cv2.dnn.blobFromImage expects CV_8U (8-bit unsigned integer) values
+        if single_img.dtype != np.uint8:
+            single_img = single_img.astype(np.uint8)

         opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]

         target_size = (300, 300)
-
-        original_size = img.shape
-
-        current_img = cv2.resize(img, target_size)
+        original_size = single_img.shape
+        current_img = cv2.resize(single_img, target_size)

         aspect_ratio_x = original_size[1] / target_size[1]
         aspect_ratio_y = original_size[0] / target_size[0]
@@ -112,7 +132,7 @@ class SsdClient(Detector):
         for face in faces:
             confidence = float(face[ssd_labels.confidence])
             x, y, w, h = map(int, face[margins])
-            detected_face = img[y : y + h, x : x + w]
+            detected_face = single_img[y : y + h, x : x + w]

             left_eye, right_eye = opencv_module.find_eyes(detected_face)

@@ -133,4 +153,5 @@ class SsdClient(Detector):
                 confidence=confidence,
             )
             resp.append(facial_area)
+
         return resp
diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py
index 233f088..cbb8879 100644
--- a/deepface/models/face_detection/Yolo.py
+++ b/deepface/models/face_detection/Yolo.py
@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import List, Any
+from typing import List, Any, Union
 from enum import Enum

 # 3rd party dependencies
@@ -62,64 +62,89 @@ class YoloDetectorClient(Detector):
         # Return face_detector
         return YOLO(weight_file)

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]:
         """
-        Detect and align face with yolo
+        Detect and align faces with yolo in a single image or a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]):
+                A list of lists of FacialAreaRegion objects for a batch of images
+                or a list of FacialAreaRegion objects for a single image
         """
-        resp = []
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]

-        # Detect faces
-        results = self.model.predict(
+        all_results = []
+
+        # Detect faces for all images at once
+        results_list = self.model.predict(
             img,
             verbose=False,
             show=False,
             conf=float(os.getenv("YOLO_MIN_DETECTION_CONFIDENCE", "0.25")),
-        )[0]
+        )

-        # For each face, extract the bounding box, the landmarks and confidence
-        for result in results:
+        # Iterate over each image's results
+        for results in results_list:
+            resp = []

-            if result.boxes is None:
-                continue
+            # For each face, extract the bounding box, the landmarks and confidence
+            for result in results:

-            # Extract the bounding box and the confidence
-            x, y, w, h = result.boxes.xywh.tolist()[0]
-            confidence = result.boxes.conf.tolist()[0]
+                if result.boxes is None:
+                    continue

-            right_eye = None
-            left_eye = None
+                # Extract the bounding box and the confidence
+                x, y, w, h = result.boxes.xywh.tolist()[0]
+                confidence = result.boxes.conf.tolist()[0]

-            # yolo-facev8 is detecting eyes through keypoints,
-            # while for v11 keypoints are always None
-            if result.keypoints is not None:
-                # right_eye_conf = result.keypoints.conf[0][0]
-                # left_eye_conf = result.keypoints.conf[0][1]
-                right_eye = result.keypoints.xy[0][0].tolist()
-                left_eye = result.keypoints.xy[0][1].tolist()
+                right_eye = None
+                left_eye = None

-                # eyes are list of float, need to cast them tuple of int
-                left_eye = tuple(int(i) for i in left_eye)
-                right_eye = tuple(int(i) for i in right_eye)
+                # yolo-facev8 is detecting eyes through keypoints,
+                # while for v11 keypoints are always None
+                if result.keypoints is not None:
+                    # right_eye_conf = result.keypoints.conf[0][0]
+                    # left_eye_conf = result.keypoints.conf[0][1]
+                    right_eye = result.keypoints.xy[0][0].tolist()
+                    left_eye = result.keypoints.xy[0][1].tolist()

-            x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
-            facial_area = FacialAreaRegion(
-                x=x,
-                y=y,
-                w=w,
-                h=h,
-                left_eye=left_eye,
-                right_eye=right_eye,
-                confidence=confidence,
-            )
-            resp.append(facial_area)
+                # eyes come as lists of float; cast them to tuples
+                # of exactly two integers, or None if unavailable
+                left_eye = (
+                    tuple(map(int, left_eye[:2]))
+                    if left_eye and len(left_eye) == 2
+                    else None
+                )
+                right_eye = (
+                    tuple(map(int, right_eye[:2]))
+                    if right_eye and len(right_eye) == 2
+                    else None
+                )
+                x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
+                facial_area = FacialAreaRegion(
+                    x=x,
+                    y=y,
+                    w=w,
+                    h=h,
+                    left_eye=left_eye,
+                    right_eye=right_eye,
+                    confidence=confidence,
+                )
+                resp.append(facial_area)

-        return resp
+            all_results.append(resp)
+
+        if not is_batched_input:
+            return all_results[0]
+        return all_results


 class YoloDetectorClientV8n(YoloDetectorClient):
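
The batched path above leans on ultralytics returning one `Results` object per input image when `predict` receives a list. A minimal sanity check of that behaviour, with a placeholder weight file (deepface resolves its own face weights):

    import numpy as np
    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")  # placeholder weights
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]
    results_list = model.predict(frames, verbose=False, conf=0.25)
    assert len(results_list) == len(frames)  # one Results object per input image
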
diff --git a/deepface/models/face_detection/YuNet.py b/deepface/models/face_detection/YuNet.py
index 9075927..93e65a8 100644
--- a/deepface/models/face_detection/YuNet.py
+++ b/deepface/models/face_detection/YuNet.py
@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import Any, List, Union

 # 3rd party dependencies
 import cv2
@@ -57,10 +57,28 @@ class YuNetClient(Detector):
             ) from err
         return face_detector

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with yunet

+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
         Args:
             img (np.ndarray): pre-loaded image as numpy array
diff --git a/deepface/modules/detection.py b/deepface/modules/detection.py
index c31a026..202d2ac 100644
--- a/deepface/modules/detection.py
+++ b/deepface/modules/detection.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, IO, List, Tuple, Union, Optional
+from typing import Any, Dict, IO, List, Tuple, Union, Optional, Sequence

 # 3rd part dependencies
 from heapq import nlargest
@@ -19,7 +19,7 @@ logger = Logger()


 def extract_faces(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[Sequence[Union[str, np.ndarray, IO[bytes]]], str, np.ndarray, IO[bytes]],
    detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
@@ -29,13 +29,14 @@ def extract_faces(
     normalize_face: bool = True,
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
-    Extract faces from a given image
+    Extract faces from a given image or list of images

     Args:
-        img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
-            as a string, numpy array (BGR), a file object that supports at least `.read` and is
+        img_path (List[str or np.ndarray or IO[bytes]] or str or np.ndarray or IO[bytes]):
+            Path(s) to the image(s) as a string,
+            numpy array (BGR), a file object that supports at least `.read` and is
             opened in binary mode, or base64 encoded images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
@@ -61,7 +62,8 @@ def extract_faces(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
+        results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]):
+            A list or a list of lists of dictionaries, where each dictionary contains:

         - "face" (np.ndarray): The detected face as a NumPy array in RGB format.

@@ -80,135 +82,158 @@ def extract_faces(
         just available in the result only if anti_spoofing is set to True in input arguments.
     """
-    resp_objs = []
-
-    # img might be path, base64 or numpy array. Convert it to numpy whatever it is.
-    img, img_name = image_utils.load_image(img_path)
-
-    if img is None:
-        raise ValueError(f"Exception while loading {img_name}")
-
-    height, width, _ = img.shape
-
-    base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
-
-    if detector_backend == "skip":
-        face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
+    batched_input = (
+        (
+            isinstance(img_path, np.ndarray) and
+            img_path.ndim == 4
+        ) or isinstance(img_path, list)
+    )
+    if not batched_input:
+        imgs_path = [img_path]
+    elif isinstance(img_path, np.ndarray):
+        imgs_path = [img_path[i] for i in range(img_path.shape[0])]
     else:
-        face_objs = detect_faces(
-            detector_backend=detector_backend,
-            img=img,
-            align=align,
-            expand_percentage=expand_percentage,
-            max_faces=max_faces,
-        )
+        imgs_path = img_path

-    # in case of no face found
-    if len(face_objs) == 0 and enforce_detection is True:
-        if img_name is not None:
-            raise ValueError(
-                f"Face could not be detected in {img_name}."
-                "Please confirm that the picture is a face photo "
-                "or consider to set enforce_detection param to False."
-            )
-        else:
-            raise ValueError(
-                "Face could not be detected. Please confirm that the picture is a face photo "
-                "or consider to set enforce_detection param to False."
-            )
+    all_images = []
+    img_names = []

-    if len(face_objs) == 0 and enforce_detection is False:
-        face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
+    for single_img_path in imgs_path:
+        # img might be path, base64 or numpy array. Convert it to numpy whatever it is.
+        img, img_name = image_utils.load_image(single_img_path)

-    for face_obj in face_objs:
-        current_img = face_obj.img
-        current_region = face_obj.facial_area
+        if img is None:
+            raise ValueError(f"Exception while loading {img_name}")

-        if current_img.shape[0] == 0 or current_img.shape[1] == 0:
-            continue
+        all_images.append(img)
+        img_names.append(img_name)

-        if grayscale is True:
-            logger.warn("Parameter grayscale is deprecated. Use color_face instead.")
-            current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
-        else:
-            if color_face == "rgb":
-                current_img = current_img[:, :, ::-1]
-            elif color_face == "bgr":
-                pass  # image is in BGR
-            elif color_face == "gray":
-                current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
-            else:
-                raise ValueError(f"The color_face can be rgb, bgr or gray, but it is {color_face}.")
+    # Run detect_faces for all images at once
+    all_face_objs = detect_faces(
+        detector_backend=detector_backend,
+        img=all_images,
+        align=align,
+        expand_percentage=expand_percentage,
+        max_faces=max_faces,
+    )
+
+    all_resp_objs = []
+
+    for img, img_name, face_objs in zip(all_images, img_names, all_face_objs):
+        height, width, _ = img.shape
+
+        if len(face_objs) == 0 and enforce_detection is True:
+            if img_name is not None:
+                raise ValueError(
+                    f"Face could not be detected in {img_name}."
+                    "Please confirm that the picture is a face photo "
+                    "or consider to set enforce_detection param to False."
+                )
+            else:
+                raise ValueError(
+                    "Face could not be detected. Please confirm that the picture is a face photo "
+                    "or consider to set enforce_detection param to False."
+                )
+
+        if len(face_objs) == 0 and enforce_detection is False:
+            base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
+            face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
+
+        img_resp_objs = []
+        for face_obj in face_objs:
+            current_img = face_obj.img
+            current_region = face_obj.facial_area
+
+            if current_img.shape[0] == 0 or current_img.shape[1] == 0:
+                continue
+
+            if grayscale is True:
+                logger.warn("Parameter grayscale is deprecated. Use color_face instead.")
+                current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
+            else:
+                if color_face == "rgb":
+                    current_img = current_img[:, :, ::-1]
+                elif color_face == "bgr":
+                    pass  # image is in BGR
+                elif color_face == "gray":
+                    current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
+                else:
+                    raise ValueError(
+                        f"The color_face can be rgb, bgr or gray, "
+                        f"but it is {color_face}."
+                    )

-        if normalize_face:
-            current_img = current_img / 255  # normalize input in [0, 1]
+            if normalize_face:
+                current_img = current_img / 255  # normalize input in [0, 1]

-        # cast to int for flask, and do final checks for borders
-        x = max(0, int(current_region.x))
-        y = max(0, int(current_region.y))
-        w = min(width - x - 1, int(current_region.w))
-        h = min(height - y - 1, int(current_region.h))
+            # cast to int for flask, and do final checks for borders
+            x = max(0, int(current_region.x))
+            y = max(0, int(current_region.y))
+            w = min(width - x - 1, int(current_region.w))
+            h = min(height - y - 1, int(current_region.h))

-        facial_area = {
-            "x": x,
-            "y": y,
-            "w": w,
-            "h": h,
-            "left_eye": current_region.left_eye,
-            "right_eye": current_region.right_eye,
-        }
+            facial_area = {
+                "x": x,
+                "y": y,
+                "w": w,
+                "h": h,
+                "left_eye": current_region.left_eye,
+                "right_eye": current_region.right_eye,
+            }

-        # optional nose, mouth_left and mouth_right fields are coming just for retinaface
-        if current_region.nose is not None:
-            facial_area["nose"] = current_region.nose
-        if current_region.mouth_left is not None:
-            facial_area["mouth_left"] = current_region.mouth_left
-        if current_region.mouth_right is not None:
-            facial_area["mouth_right"] = current_region.mouth_right
+            # optional nose, mouth_left and mouth_right fields are coming just for retinaface
+            if current_region.nose is not None:
+                facial_area["nose"] = current_region.nose
+            if current_region.mouth_left is not None:
+                facial_area["mouth_left"] = current_region.mouth_left
+            if current_region.mouth_right is not None:
+                facial_area["mouth_right"] = current_region.mouth_right

-        resp_obj = {
-            "face": current_img,
-            "facial_area": facial_area,
-            "confidence": round(float(current_region.confidence or 0), 2),
-        }
+            resp_obj = {
+                "face": current_img,
+                "facial_area": facial_area,
+                "confidence": round(float(current_region.confidence or 0), 2),
+            }

-        if anti_spoofing is True:
-            antispoof_model = modeling.build_model(task="spoofing", model_name="Fasnet")
-            is_real, antispoof_score = antispoof_model.analyze(img=img, facial_area=(x, y, w, h))
-            resp_obj["is_real"] = is_real
-            resp_obj["antispoof_score"] = antispoof_score
+            if anti_spoofing is True:
+                antispoof_model = modeling.build_model(task="spoofing", model_name="Fasnet")
+                is_real, antispoof_score = antispoof_model.analyze(
+                    img=img,
+                    facial_area=(x, y, w, h)
+                )
+                resp_obj["is_real"] = is_real
+                resp_obj["antispoof_score"] = antispoof_score

-        resp_objs.append(resp_obj)
+            img_resp_objs.append(resp_obj)

-    if len(resp_objs) == 0 and enforce_detection == True:
-        raise ValueError(
-            f"Exception while extracting faces from {img_name}."
-            "Consider to set enforce_detection arg to False."
-        )
+        all_resp_objs.append(img_resp_objs)

-    return resp_objs
+    if not batched_input:
+        return all_resp_objs[0]
+    return all_resp_objs


 def detect_faces(
     detector_backend: str,
-    img: np.ndarray,
+    img: Union[np.ndarray, List[np.ndarray]],
     align: bool = True,
     expand_percentage: int = 0,
     max_faces: Optional[int] = None,
-) -> List[DetectedFace]:
+) -> Union[List[List[DetectedFace]], List[DetectedFace]]:
     """
-    Detect face(s) from a given image
+    Detect face(s) from a given image or list of images

     Args:
         detector_backend (str): detector name

-        img (np.ndarray): pre-loaded image
+        img (np.ndarray or List[np.ndarray]): pre-loaded image or list of images

         align (bool): enable or disable alignment after detection

         expand_percentage (int): expand detected facial area with a percentage (default is 0).

     Returns:
-        results (List[DetectedFace]): A list of DetectedFace objects
+        results (Union[List[List[DetectedFace]], List[DetectedFace]]):
+            A list of lists of DetectedFace objects or a list of DetectedFace objects,
             where each object contains:

         - img (np.ndarray): The detected face as a NumPy array.
@@ -219,53 +244,113 @@ def detect_faces(

         - confidence (float): The confidence score associated with the detected face.
     """
-    height, width, _ = img.shape
+    batched_input = (
+        (
+            isinstance(img, np.ndarray) and
+            img.ndim == 4
+        ) or isinstance(img, list)
+    )
+    if not batched_input:
+        imgs = [img]
+    elif isinstance(img, np.ndarray):
+        imgs = [img[i] for i in range(img.shape[0])]
+    else:
+        imgs = img
+
+    if detector_backend == "skip":
+        all_face_objs = [
+            [
+                DetectedFace(
+                    img=single_img,
+                    facial_area=FacialAreaRegion(
+                        x=0, y=0, w=single_img.shape[1], h=single_img.shape[0]
+                    ),
+                    confidence=0,
+                )
+            ]
+            for single_img in imgs
+        ]
+        if not batched_input:
+            all_face_objs = all_face_objs[0]
+        return all_face_objs
+
     face_detector: Detector = modeling.build_model(
         task="face_detector", model_name=detector_backend
     )

-    # validate expand percentage score
-    if expand_percentage < 0:
-        logger.warn(
-            f"Expand percentage cannot be negative but you set it to {expand_percentage}."
-            "Overwritten it to 0."
-        )
-        expand_percentage = 0
+    preprocessed_images = []
+    width_borders = []
+    height_borders = []
+    for single_img in imgs:
+        height, width, _ = single_img.shape

-    # If faces are close to the upper boundary, alignment move them outside
-    # Add a black border around an image to avoid this.
-    height_border = int(0.5 * height)
-    width_border = int(0.5 * width)
-    if align is True:
-        img = cv2.copyMakeBorder(
-            img,
-            height_border,
-            height_border,
-            width_border,
-            width_border,
-            cv2.BORDER_CONSTANT,
-            value=[0, 0, 0],  # Color of the border (black)
-        )
+        # validate expand percentage score
+        if expand_percentage < 0:
+            logger.warn(
+                f"Expand percentage cannot be negative but you set it to {expand_percentage}."
+                "Overwritten it to 0."
+            )
+            expand_percentage = 0

-    # find facial areas of given image
-    facial_areas = face_detector.detect_faces(img)
+        # If faces are close to the upper boundary, alignment move them outside
+        # Add a black border around an image to avoid this.
+        height_border = int(0.5 * height)
+        width_border = int(0.5 * width)
+        if align is True:
+            single_img = cv2.copyMakeBorder(
+                single_img,
+                height_border,
+                height_border,
+                width_border,
+                width_border,
+                cv2.BORDER_CONSTANT,
+                value=[0, 0, 0],  # Color of the border (black)
+            )

-    if max_faces is not None and max_faces < len(facial_areas):
-        facial_areas = nlargest(
-            max_faces, facial_areas, key=lambda facial_area: facial_area.w * facial_area.h
-        )
+        preprocessed_images.append(single_img)
+        width_borders.append(width_border)
+        height_borders.append(height_border)

-    return [
-        extract_face(
-            facial_area=facial_area,
-            img=img,
-            align=align,
-            expand_percentage=expand_percentage,
-            width_border=width_border,
-            height_border=height_border,
-        )
-        for facial_area in facial_areas
-    ]
+    # Detect faces in all preprocessed images
+    all_facial_areas = face_detector.detect_faces(preprocessed_images)
+
+    all_detected_faces = []
+    for (
+        single_img,
+        facial_areas,
+        width_border,
+        height_border
+    ) in zip(
+        preprocessed_images,
+        all_facial_areas,
+        width_borders,
+        height_borders
+    ):
+        if not isinstance(facial_areas, list):
+            facial_areas = [facial_areas]
+
+        if max_faces is not None and max_faces < len(facial_areas):
+            facial_areas = nlargest(
+                max_faces, facial_areas, key=lambda facial_area: facial_area.w * facial_area.h
+            )
+
+        detected_faces = [
+            extract_face(
+                facial_area=facial_area,
+                img=single_img,
+                align=align,
+                expand_percentage=expand_percentage,
+                width_border=width_border,
+                height_border=height_border,
+            )
+            for facial_area in facial_areas if isinstance(facial_area, FacialAreaRegion)
+        ]
+
+        all_detected_faces.append(detected_faces)
+
+    if not batched_input:
+        return all_detected_faces[0]
+    return all_detected_faces


 def extract_face(
diff --git a/tests/test_extract_faces.py b/tests/test_extract_faces.py
index 262d22d..1ecaafa 100644
--- a/tests/test_extract_faces.py
+++ b/tests/test_extract_faces.py
@@ -23,6 +23,10 @@ def test_different_detectors():

     for detector in detectors:
         img_objs = DeepFace.extract_faces(img_path=img_path, detector_backend=detector)
+
+        # Check return type for non-batch input
+        assert isinstance(img_objs, list) and all(isinstance(obj, dict) for obj in img_objs)
+
         for img_obj in img_objs:
             assert "face" in img_obj.keys()
             assert "facial_area" in img_obj.keys()
@@ -79,6 +83,169 @@ def test_different_detectors():
         logger.info(f"✅ extract_faces for {detector} backend test is done")


+@pytest.mark.parametrize("detector_backend", [
+    # "opencv",
+    "ssd",
+    "mtcnn",
+    "retinaface",
+    "yunet",
+    "centerface",
+    # optional
+    # "yolov11s",
+    # "mediapipe",
+    # "dlib",
+])
+def test_batch_extract_faces(detector_backend):
+    # Relative tolerance for comparing floating-point values
+    rtol = 0.03
+    img_paths = [
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img11.jpg",
+        "dataset/couple.jpg"
+    ]
+    expected_num_faces = [1, 1, 1, 2]
+
+    # Extract faces one by one
+    imgs_objs_individual = [
+        DeepFace.extract_faces(
+            img_path=img_path,
+            detector_backend=detector_backend,
+            align=True,
+        ) for img_path in img_paths
+    ]
+
+    # Check that individual extraction returns a list of dictionaries
+    for img_objs_individual in imgs_objs_individual:
+        assert isinstance(img_objs_individual, list)
+        assert all(isinstance(face, dict) for face in img_objs_individual)
+
+    # Check that the individual extraction results match the expected number of faces
+    for img_objs_individual, expected_faces in zip(imgs_objs_individual, expected_num_faces):
+        assert len(img_objs_individual) == expected_faces
+
+    # Extract faces in batch
+    imgs_objs_batch = DeepFace.extract_faces(
+        img_path=img_paths,
+        detector_backend=detector_backend,
+        align=True,
+    )
+
+    # Check that the batch extraction returned the expected number of face lists
+    assert len(imgs_objs_batch) == len(img_paths)
+
+    # Check that each face list has the expected number of faces
+    for i, expected_faces in enumerate(expected_num_faces):
+        assert len(imgs_objs_batch[i]) == expected_faces
+
+    # Check that the individual extraction results match the batch extraction results
+    for img_objs_individual, img_objs_batch in zip(imgs_objs_individual, imgs_objs_batch):
+        assert len(img_objs_batch) == len(img_objs_individual), (
+            "Batch and individual extraction results should have the same number of detected faces"
+        )
+        for img_obj_individual, img_obj_batch in zip(img_objs_individual, img_objs_batch):
+            for key in img_obj_individual["facial_area"]:
+                if isinstance(img_obj_individual["facial_area"][key], tuple):
+                    for ind_val, batch_val in zip(
+                        img_obj_individual["facial_area"][key],
+                        img_obj_batch["facial_area"][key]
+                    ):
+                        # Ensure the difference between individual and batch values
+                        # is within rtol% of the individual value
+                        assert abs(ind_val - batch_val) <= rtol * ind_val
+                elif isinstance(img_obj_individual["facial_area"][key], (int, float)):
+                    # Ensure the difference between individual and batch values
+                    # is within rtol% of the individual value
+                    assert abs(
+                        img_obj_individual["facial_area"][key] -
+                        img_obj_batch["facial_area"][key]
+                    ) <= rtol * img_obj_individual["facial_area"][key]
+            # Ensure the confidence difference is within rtol% of the individual confidence
+            assert abs(
+                img_obj_individual["confidence"] -
+                img_obj_batch["confidence"]
+            ) <= rtol * img_obj_individual["confidence"]
+
+
+@pytest.mark.parametrize("detector_backend", [
+    "opencv",
+    "ssd",
+    "mtcnn",
+    "retinaface",
+    "yunet",
+    # "centerface",
+    # optional
+    # "yolov11s",
+    # "mediapipe",
+    # "dlib",
+])
+def test_batch_extract_faces_with_nparray(detector_backend):
+    img_paths = [
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img11.jpg",
+        "dataset/couple.jpg"
+    ]
+    # load images as numpy arrays, resized to a common shape
+    imgs = [
+        cv2.resize(image_utils.load_image(img_path)[0], (1920, 1080))
+        for img_path in img_paths
+    ]
+    expected_num_faces = [1, 1, 1, 2]
+
+    # stack images into a single 4-D numpy array
+    imgs_batch = np.stack(imgs, axis=0)
+
+    # extract faces from the stacked batch of numpy arrays
+    imgs_objs_batch = DeepFace.extract_faces(
+        img_path=imgs_batch,
+        detector_backend=detector_backend,
+        align=True,
+        enforce_detection=False,
+    )
+
+    # Check return type for batch input
+    assert (
+        isinstance(imgs_objs_batch, list) and
+        all(
+            isinstance(obj, list) and
+            all(isinstance(face, dict) for face in obj)
+            for obj in imgs_objs_batch
+        )
+    )
+
+    # Check that the batch extraction returned the expected number of face lists
+    assert len(imgs_objs_batch) == len(img_paths)
+    for img_objs_batch, img_expected_num_faces in zip(imgs_objs_batch, expected_num_faces):
+        assert len(img_objs_batch) == img_expected_num_faces
+
+    # extract faces again, this time passing a list of numpy arrays
+    imgs_objs_batch_list = DeepFace.extract_faces(
+        img_path=imgs,
+        detector_backend=detector_backend,
+        align=True,
+        enforce_detection=False,
+    )
+
+    # compare results of the stacked array against the list input
+    for img_objs_batch, img_objs_list in zip(imgs_objs_batch, imgs_objs_batch_list):
+        assert len(img_objs_batch) == len(img_objs_list), (
+            "Stacked array and list inputs should yield the same number of detected faces"
+        )
+
+
+def test_batch_extract_faces_single_image():
+    img_path = "dataset/couple.jpg"
+    imgs_objs_batch = DeepFace.extract_faces(
+        img_path=[img_path],
+        align=True,
+    )
+    assert len(imgs_objs_batch) == 1 and isinstance(imgs_objs_batch[0], list)
+    assert all(isinstance(obj, dict) for obj in imgs_objs_batch[0])
+
+
 def test_backends_for_enforced_detection_with_non_facial_inputs():
     black_img = np.zeros([224, 224, 3])
     for detector in detectors: