Merge 3111a2895a0e2561403ea193dd2f75f069160d57 into df1b6ab6fe431f126afe7ff3baa7fe394a967f0a

galthran-wq 2025-05-17 18:30:50 -07:00 committed by GitHub
commit ab84683bbe
14 changed files with 801 additions and 326 deletions

View File

@@ -521,7 +521,7 @@ def stream(
 def extract_faces(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
     detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
@@ -530,14 +530,14 @@ def extract_faces(
     color_face: str = "rgb",
     normalize_face: bool = True,
     anti_spoofing: bool = False,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
-    Extract faces from a given image
+    Extract faces from a given image or sequence of images.

     Args:
-        img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
-            as a string, numpy array (BGR), a file object that supports at least `.read` and is
-            opened in binary mode, or base64 encoded images.
+        img_path (Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]]):
+            Path(s) to the image(s). Accepts a string path, a numpy array (BGR), a file object
+            that supports at least `.read` and is opened in binary mode, or base64 encoded images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
             'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
@@ -562,7 +562,8 @@ def extract_faces(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
+        results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]):
+            A list or a list of lists of dictionaries, where each dictionary contains:

         - "face" (np.ndarray): The detected face as a NumPy array.

View File

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Union
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 import numpy as np
@@ -9,15 +9,20 @@ import numpy as np
 # pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
 class Detector(ABC):
     @abstractmethod
-    def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List["FacialAreaRegion"], List[List["FacialAreaRegion"]]]:
         """
-        Interface for detect and align face
+        Interface to detect and align faces in a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]):
+                A list or a list of lists of FacialAreaRegion objects
             where each object contains:

         - facial_area (FacialAreaRegion): The facial area region represented
@@ -28,6 +33,7 @@ class Detector(ABC):
         """
         pass

+# pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
 @dataclass
 class FacialAreaRegion:
     """

View File

@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -34,12 +34,35 @@ class CenterFaceClient(Detector):
         return CenterFace(weight_path=weights_path)

-    def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with CenterFace

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
+        Args:
+            single_img (np.ndarray): pre-loaded image as numpy array

         Returns:
             results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
@@ -53,7 +76,7 @@ class CenterFaceClient(Detector):
         #     img, img.shape[0], img.shape[1], threshold=threshold
         # )
         detections, landmarks = self.build_model().forward(
-            img, img.shape[0], img.shape[1], threshold=threshold
+            single_img, single_img.shape[0], single_img.shape[1], threshold=threshold
        )

        for i, detection in enumerate(detections):
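The CenterFace change above introduces the wrap/dispatch/unwrap idiom that the Dlib, FastMtCnn, MediaPipe, Ssd, and YuNet clients below repeat almost verbatim. A standalone sketch of the pattern, with the backend-specific detection call stubbed out:

```python
from typing import List, Union

import numpy as np


def _process_single_image(single_img: np.ndarray) -> List[dict]:
    # stand-in for the backend-specific detector call
    return []


def detect_faces(img: Union[np.ndarray, List[np.ndarray]]):
    # Normalize a single image into a one-element batch, run the per-image
    # helper, then unwrap so single-image callers keep the old return type.
    is_batched_input = isinstance(img, list)
    if not is_batched_input:
        img = [img]
    results = [_process_single_image(single_img) for single_img in img]
    return results if is_batched_input else results[0]
```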

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -47,10 +47,33 @@ class DlibClient(Detector):
         detector["sp"] = sp
         return detector

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with dlib
+
+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.

         Args:
             img (np.ndarray): pre-loaded image as numpy array

View File

@@ -17,10 +17,33 @@ class FastMtCnnClient(Detector):
     def __init__(self):
         self.model = self.build_model()

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with mtcnn
+
+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.

         Args:
             img (np.ndarray): pre-loaded image as numpy array

View File

@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import Any, List, Union

 # 3rd party dependencies
 import numpy as np
@@ -43,10 +43,33 @@ class MediaPipeClient(Detector):
         )
         return face_detection

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]],
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with mediapipe
+
+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.

         Args:
             img (np.ndarray): pre-loaded image as numpy array

View File

@@ -1,5 +1,6 @@
 # built-in dependencies
-from typing import List
+import logging
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -8,6 +9,8 @@ from mtcnn import MTCNN
 # project dependencies
 from deepface.models.Detector import Detector, FacialAreaRegion

+logger = logging.getLogger(__name__)
+
 # pylint: disable=too-few-public-methods
 class MtCnnClient(Detector):
     """
@@ -16,45 +19,71 @@ class MtCnnClient(Detector):

     def __init__(self):
         self.model = MTCNN()
+        self.supports_batch_detection = self._supports_batch_detection()

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
-        Detect and align face with mtcnn
+        Detect and align faces with mtcnn for a single image or a list of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list of FacialAreaRegion objects for a single image
+                or a list of lists of FacialAreaRegion objects for each image
         """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]

         resp = []

         # mtcnn expects RGB but OpenCV read BGR
         # img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        img_rgb = img[:, :, ::-1]
-        detections = self.model.detect_faces(img_rgb)
+        img_rgb = [single_img[:, :, ::-1] for single_img in img]
+        if self.supports_batch_detection:
+            detections = self.model.detect_faces(img_rgb)
+        else:
+            detections = [self.model.detect_faces(single_img) for single_img in img_rgb]

-        if detections is not None and len(detections) > 0:
-
-            for current_detection in detections:
-                x, y, w, h = current_detection["box"]
-                confidence = current_detection["confidence"]
-                # mtcnn detector assigns left eye with respect to the observer
-                # but we are setting it with respect to the person itself
-                left_eye = current_detection["keypoints"]["right_eye"]
-                right_eye = current_detection["keypoints"]["left_eye"]
-
-                facial_area = FacialAreaRegion(
-                    x=x,
-                    y=y,
-                    w=w,
-                    h=h,
-                    left_eye=left_eye,
-                    right_eye=right_eye,
-                    confidence=confidence,
-                )
-                resp.append(facial_area)
+        for image_detections in detections:
+            image_resp = []
+            if image_detections is not None and len(image_detections) > 0:
+                for current_detection in image_detections:
+                    x, y, w, h = current_detection["box"]
+                    confidence = current_detection["confidence"]
+                    # mtcnn detector assigns left eye with respect to the observer
+                    # but we are setting it with respect to the person itself
+                    left_eye = current_detection["keypoints"]["right_eye"]
+                    right_eye = current_detection["keypoints"]["left_eye"]
+
+                    facial_area = FacialAreaRegion(
+                        x=x,
+                        y=y,
+                        w=w,
+                        h=h,
+                        left_eye=left_eye,
+                        right_eye=right_eye,
+                        confidence=confidence,
+                    )
+                    image_resp.append(facial_area)
+            resp.append(image_resp)
+
+        if not is_batched_input:
+            return resp[0]

         return resp
+
+    def _supports_batch_detection(self) -> bool:
+        logger.warning(
+            "Batch detection is disabled for mtcnn by default "
+            "since the results are not consistent with single image detection."
+        )
+        return False
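The `_supports_batch_detection` hook makes true batch inference opt-in: while it returns False, `detect_faces` falls back to a per-image loop. A hypothetical subclass could flip the flag once batch and single-image results agree (class name and module path are assumptions, not part of this diff):

```python
from deepface.models.face_detection.MtCnn import MtCnnClient  # assumed module path


class ExperimentalBatchMtCnnClient(MtCnnClient):
    # hypothetical opt-in: route whole lists through self.model.detect_faces()
    def _supports_batch_detection(self) -> bool:
        return True
```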

View File

@@ -1,6 +1,7 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import Any, List, Union
+import logging

 # 3rd party dependencies
 import cv2
@@ -9,6 +10,7 @@ import numpy as np
 #project dependencies
 from deepface.models.Detector import Detector, FacialAreaRegion

+logger = logging.getLogger(__name__)

 class OpenCvClient(Detector):
     """
@@ -17,6 +19,7 @@ class OpenCvClient(Detector):

     def __init__(self):
         self.model = self.build_model()
+        self.supports_batch_detection = self._supports_batch_detection()

     def build_model(self):
         """
@@ -29,55 +32,72 @@ class OpenCvClient(Detector):
         detector["eye_detector"] = self.__build_cascade("haarcascade_eye")
         return detector

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def _supports_batch_detection(self) -> bool:
+        logger.warning(
+            "Batch detection is disabled for opencv by default "
+            "since the results are not consistent with single image detection."
+        )
+        return False
+
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with opencv

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
         """
-        resp = []
-
-        detected_face = None
-
-        faces = []
-        try:
-            # faces = detector["face_detector"].detectMultiScale(img, 1.3, 5)
-
-            # note that, by design, opencv's haarcascade scores are >0 but not capped at 1
-            faces, _, scores = self.model["face_detector"].detectMultiScale3(
-                img, 1.1, 10, outputRejectLevels=True
-            )
-        except:
-            pass
-
-        if len(faces) > 0:
-            for (x, y, w, h), confidence in zip(faces, scores):
-                detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
-                left_eye, right_eye = self.find_eyes(img=detected_face)
-
-                # eyes found in the detected face instead of the image itself;
-                # detected face's coordinates should be added
-                if left_eye is not None:
-                    left_eye = (int(x + left_eye[0]), int(y + left_eye[1]))
-                if right_eye is not None:
-                    right_eye = (int(x + right_eye[0]), int(y + right_eye[1]))
-
-                facial_area = FacialAreaRegion(
-                    x=x,
-                    y=y,
-                    w=w,
-                    h=h,
-                    left_eye=left_eye,
-                    right_eye=right_eye,
-                    confidence=(100 - confidence) / 100,
-                )
-                resp.append(facial_area)
-
-        return resp
+        if isinstance(img, np.ndarray):
+            imgs = [img]
+        elif self.supports_batch_detection:
+            imgs = img
+        else:
+            return [self.detect_faces(single_img) for single_img in img]
+
+        batch_results = []
+        for single_img in imgs:
+            resp = []
+            detected_face = None
+            faces = []
+            try:
+                # note that, by design, opencv's haarcascade scores are >0 but not capped at 1
+                faces, _, scores = self.model["face_detector"].detectMultiScale3(
+                    single_img, 1.1, 10, outputRejectLevels=True
+                )
+            except:
+                pass
+
+            if len(faces) > 0:
+                for (x, y, w, h), confidence in zip(faces, scores):
+                    detected_face = single_img[int(y):int(y + h), int(x):int(x + w)]
+                    left_eye, right_eye = self.find_eyes(img=detected_face)
+
+                    # eyes found in the detected face instead of the image itself;
+                    # detected face's coordinates should be added
+                    if left_eye is not None:
+                        left_eye = (int(x + left_eye[0]), int(y + left_eye[1]))
+                    if right_eye is not None:
+                        right_eye = (int(x + right_eye[0]), int(y + right_eye[1]))
+
+                    facial_area = FacialAreaRegion(
+                        x=x,
+                        y=y,
+                        w=w,
+                        h=h,
+                        left_eye=left_eye,
+                        right_eye=right_eye,
+                        confidence=(100 - confidence) / 100,
+                    )
+                    resp.append(facial_area)
+            batch_results.append(resp)
+
+        return batch_results if isinstance(img, list) else batch_results[0]

     def find_eyes(self, img: np.ndarray) -> tuple:
         """

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -13,64 +13,75 @@ class RetinaFaceClient(Detector):
     def __init__(self):
         self.model = rf.build_model()

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
-        Detect and align face with retinaface
+        Detect and align faces with retinaface in a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
         """
-        resp = []
-
-        obj = rf.detect_faces(img, model=self.model, threshold=0.9)
-
-        if not isinstance(obj, dict):
-            return resp
-
-        for face_idx in obj.keys():
-            identity = obj[face_idx]
-            detection = identity["facial_area"]
-
-            y = detection[1]
-            h = detection[3] - y
-            x = detection[0]
-            w = detection[2] - x
-
-            # retinaface sets left and right eyes with respect to the person
-            left_eye = identity["landmarks"]["left_eye"]
-            right_eye = identity["landmarks"]["right_eye"]
-            nose = identity["landmarks"].get("nose")
-            mouth_right = identity["landmarks"].get("mouth_right")
-            mouth_left = identity["landmarks"].get("mouth_left")
-
-            # eyes are list of float, need to cast them tuple of int
-            left_eye = tuple(int(i) for i in left_eye)
-            right_eye = tuple(int(i) for i in right_eye)
-            if nose is not None:
-                nose = tuple(int(i) for i in nose)
-            if mouth_right is not None:
-                mouth_right = tuple(int(i) for i in mouth_right)
-            if mouth_left is not None:
-                mouth_left = tuple(int(i) for i in mouth_left)
-
-            confidence = identity["score"]
-
-            facial_area = FacialAreaRegion(
-                x=x,
-                y=y,
-                w=w,
-                h=h,
-                left_eye=left_eye,
-                right_eye=right_eye,
-                confidence=confidence,
-                nose=nose,
-                mouth_left=mouth_left,
-                mouth_right=mouth_right,
-            )
-            resp.append(facial_area)
-
-        return resp
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            imgs = [img]
+        else:
+            imgs = img
+
+        batch_results = []
+        for single_img in imgs:
+            resp = []
+            obj = rf.detect_faces(single_img, model=self.model, threshold=0.9)
+
+            if isinstance(obj, dict):
+                for face_idx in obj.keys():
+                    identity = obj[face_idx]
+                    detection = identity["facial_area"]
+
+                    y = detection[1]
+                    h = detection[3] - y
+                    x = detection[0]
+                    w = detection[2] - x
+
+                    # eyes are lists of float; cast them to tuples of int
+                    left_eye = tuple(int(i) for i in identity["landmarks"]["left_eye"])
+                    right_eye = tuple(int(i) for i in identity["landmarks"]["right_eye"])
+                    nose = identity["landmarks"].get("nose")
+                    mouth_right = identity["landmarks"].get("mouth_right")
+                    mouth_left = identity["landmarks"].get("mouth_left")
+
+                    if nose is not None:
+                        nose = tuple(int(i) for i in nose)
+                    if mouth_right is not None:
+                        mouth_right = tuple(int(i) for i in mouth_right)
+                    if mouth_left is not None:
+                        mouth_left = tuple(int(i) for i in mouth_left)
+
+                    confidence = identity["score"]
+
+                    facial_area = FacialAreaRegion(
+                        x=x,
+                        y=y,
+                        w=w,
+                        h=h,
+                        left_eye=left_eye,
+                        right_eye=right_eye,
+                        confidence=confidence,
+                        nose=nose,
+                        mouth_left=mouth_left,
+                        mouth_right=mouth_right,
+                    )
+                    resp.append(facial_area)
+            batch_results.append(resp)
+
+        if not is_batched_input:
+            return batch_results[0]
+        return batch_results

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union
 from enum import IntEnum

 # 3rd party dependencies
@@ -54,28 +54,48 @@ class SsdClient(Detector):
         return {"face_detector": face_detector, "opencv_module": OpenCv.OpenCvClient()}

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
-        Detect and align face with ssd
+        Detect and align faces with ssd in a batch of images

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.
+
+        Args:
+            single_img (np.ndarray): Pre-loaded image as numpy array

         Returns:
             results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
         """

         # Because cv2.dnn.blobFromImage expects CV_8U (8-bit unsigned integer) values
-        if img.dtype != np.uint8:
-            img = img.astype(np.uint8)
+        if single_img.dtype != np.uint8:
+            single_img = single_img.astype(np.uint8)

         opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]

         target_size = (300, 300)
-
-        original_size = img.shape
-        current_img = cv2.resize(img, target_size)
+        original_size = single_img.shape
+        current_img = cv2.resize(single_img, target_size)

         aspect_ratio_x = original_size[1] / target_size[1]
         aspect_ratio_y = original_size[0] / target_size[0]
@@ -112,7 +132,7 @@ class SsdClient(Detector):
         for face in faces:
             confidence = float(face[ssd_labels.confidence])
             x, y, w, h = map(int, face[margins])
-            detected_face = img[y : y + h, x : x + w]
+            detected_face = single_img[y : y + h, x : x + w]

             left_eye, right_eye = opencv_module.find_eyes(detected_face)
@@ -133,4 +153,5 @@ class SsdClient(Detector):
                 confidence=confidence,
             )
             resp.append(facial_area)
+
         return resp

View File

@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import List, Any
+from typing import List, Any, Union
 from enum import Enum

 # 3rd party dependencies
@@ -62,64 +62,89 @@ class YoloDetectorClient(Detector):
         # Return face_detector
         return YOLO(weight_file)

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]:
         """
-        Detect and align face with yolo
+        Detect and align faces in a batch of images with yolo

         Args:
-            img (np.ndarray): pre-loaded image as numpy array
+            img (Union[np.ndarray, List[np.ndarray]]):
+                Pre-loaded image as numpy array or a list of those

         Returns:
-            results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
+            results (Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]):
+                A list of lists of FacialAreaRegion objects for each image
+                or a list of FacialAreaRegion objects for a single image
         """
-        resp = []
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]

-        # Detect faces
-        results = self.model.predict(
+        all_results = []
+
+        # Detect faces for all images in one call
+        results_list = self.model.predict(
             img,
             verbose=False,
             show=False,
             conf=float(os.getenv("YOLO_MIN_DETECTION_CONFIDENCE", "0.25")),
-        )[0]
+        )

-        # For each face, extract the bounding box, the landmarks and confidence
-        for result in results:
+        # Iterate over each image's results
+        for results in results_list:
+            resp = []

-            if result.boxes is None:
-                continue
+            # For each face, extract the bounding box, the landmarks and confidence
+            for result in results:

-            # Extract the bounding box and the confidence
-            x, y, w, h = result.boxes.xywh.tolist()[0]
-            confidence = result.boxes.conf.tolist()[0]
+                if result.boxes is None:
+                    continue

-            right_eye = None
-            left_eye = None
+                # Extract the bounding box and the confidence
+                x, y, w, h = result.boxes.xywh.tolist()[0]
+                confidence = result.boxes.conf.tolist()[0]

-            # yolo-facev8 is detecting eyes through keypoints,
-            # while for v11 keypoints are always None
-            if result.keypoints is not None:
-                # right_eye_conf = result.keypoints.conf[0][0]
-                # left_eye_conf = result.keypoints.conf[0][1]
-                right_eye = result.keypoints.xy[0][0].tolist()
-                left_eye = result.keypoints.xy[0][1].tolist()
+                right_eye = None
+                left_eye = None

-            # eyes are list of float, need to cast them tuple of int
-            left_eye = tuple(int(i) for i in left_eye)
-            right_eye = tuple(int(i) for i in right_eye)
+                # yolo-facev8 is detecting eyes through keypoints,
+                # while for v11 keypoints are always None
+                if result.keypoints is not None:
+                    # right_eye_conf = result.keypoints.conf[0][0]
+                    # left_eye_conf = result.keypoints.conf[0][1]
+                    right_eye = result.keypoints.xy[0][0].tolist()
+                    left_eye = result.keypoints.xy[0][1].tolist()

-            x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
-            facial_area = FacialAreaRegion(
-                x=x,
-                y=y,
-                w=w,
-                h=h,
-                left_eye=left_eye,
-                right_eye=right_eye,
-                confidence=confidence,
-            )
-            resp.append(facial_area)
+                # Ensure eyes are tuples of exactly two integers or None
+                left_eye = (
+                    tuple(map(int, left_eye[:2]))
+                    if left_eye and len(left_eye) == 2
+                    else None
+                )
+                right_eye = (
+                    tuple(map(int, right_eye[:2]))
+                    if right_eye and len(right_eye) == 2
+                    else None
+                )

-        return resp
+                x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
+                facial_area = FacialAreaRegion(
+                    x=x,
+                    y=y,
+                    w=w,
+                    h=h,
+                    left_eye=left_eye,
+                    right_eye=right_eye,
+                    confidence=confidence,
+                )
+                resp.append(facial_area)
+
+            all_results.append(resp)
+
+        if not is_batched_input:
+            return all_results[0]
+        return all_results

 class YoloDetectorClientV8n(YoloDetectorClient):
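Unlike the backends that loop per image, yolo gets real batching for free: ultralytics' `predict()` accepts a list of images and returns one `Results` object per image, which is what the loop above iterates over. A hedged standalone sketch (weight file and paths are illustrative):

```python
from ultralytics import YOLO  # assumes the ultralytics package is installed

model = YOLO("yolov8n-face.pt")  # illustrative weight file

# predict() on a list returns a list of Results, one per input image
results_list = model.predict(["img1.jpg", "img2.jpg"], verbose=False)
for results in results_list:
    print(len(results.boxes))  # number of detected faces in that image
```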

View File

@@ -1,6 +1,6 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import Any, List, Union

 # 3rd party dependencies
 import cv2
@@ -57,10 +57,28 @@ class YuNetClient(Detector):
             ) from err
         return face_detector

-    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
+    def detect_faces(
+        self,
+        img: Union[np.ndarray, List[np.ndarray]]
+    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
         """
         Detect and align face with yunet
+
+        Args:
+            img (Union[np.ndarray, List[np.ndarray]]):
+                pre-loaded image as numpy array or a list of those
+
+        Returns:
+            results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
+                A list or a list of lists of FacialAreaRegion objects
+        """
+        is_batched_input = isinstance(img, list)
+        if not is_batched_input:
+            img = [img]
+        results = [self._process_single_image(single_img) for single_img in img]
+        if not is_batched_input:
+            return results[0]
+        return results
+
+    def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
+        """
+        Helper function to detect faces in a single image.

         Args:
             img (np.ndarray): pre-loaded image as numpy array

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, IO, List, Tuple, Union, Optional
+from typing import Any, Dict, IO, List, Tuple, Union, Optional, Sequence

 # 3rd part dependencies
 from heapq import nlargest
@@ -19,7 +19,7 @@ logger = Logger()

 def extract_faces(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[Sequence[Union[str, np.ndarray, IO[bytes]]], str, np.ndarray, IO[bytes]],
     detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
@@ -29,13 +29,14 @@ def extract_faces(
     normalize_face: bool = True,
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
-    Extract faces from a given image
+    Extract faces from a given image or list of images

     Args:
-        img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
-            as a string, numpy array (BGR), a file object that supports at least `.read` and is
-            opened in binary mode, or base64 encoded images.
+        img_path (List[str or np.ndarray or IO[bytes]] or str or np.ndarray or IO[bytes]):
+            Path(s) to the image(s) as a string, numpy array (BGR), a file object that
+            supports at least `.read` and is opened in binary mode, or base64 encoded images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
@@ -61,7 +62,8 @@ def extract_faces(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
+        results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]):
+            A list or a list of lists of dictionaries, where each dictionary contains:

         - "face" (np.ndarray): The detected face as a NumPy array in RGB format.
@@ -80,135 +82,158 @@ def extract_faces(
        just available in the result only if anti_spoofing is set to True in input arguments.
     """
-    resp_objs = []
-
-    # img might be path, base64 or numpy array. Convert it to numpy whatever it is.
-    img, img_name = image_utils.load_image(img_path)
-
-    if img is None:
-        raise ValueError(f"Exception while loading {img_name}")
-
-    height, width, _ = img.shape
-
-    base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
-
-    if detector_backend == "skip":
-        face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
-    else:
-        face_objs = detect_faces(
-            detector_backend=detector_backend,
-            img=img,
-            align=align,
-            expand_percentage=expand_percentage,
-            max_faces=max_faces,
-        )
-
-    # in case of no face found
-    if len(face_objs) == 0 and enforce_detection is True:
-        if img_name is not None:
-            raise ValueError(
-                f"Face could not be detected in {img_name}."
-                "Please confirm that the picture is a face photo "
-                "or consider to set enforce_detection param to False."
-            )
-        else:
-            raise ValueError(
-                "Face could not be detected. Please confirm that the picture is a face photo "
-                "or consider to set enforce_detection param to False."
-            )
-
-    if len(face_objs) == 0 and enforce_detection is False:
-        face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
-
-    for face_obj in face_objs:
-        current_img = face_obj.img
-        current_region = face_obj.facial_area
-
-        if current_img.shape[0] == 0 or current_img.shape[1] == 0:
-            continue
-
-        if grayscale is True:
-            logger.warn("Parameter grayscale is deprecated. Use color_face instead.")
-            current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
-        else:
-            if color_face == "rgb":
-                current_img = current_img[:, :, ::-1]
-            elif color_face == "bgr":
-                pass  # image is in BGR
-            elif color_face == "gray":
-                current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
-            else:
-                raise ValueError(f"The color_face can be rgb, bgr or gray, but it is {color_face}.")
-
-        if normalize_face:
-            current_img = current_img / 255  # normalize input in [0, 1]
-
-        # cast to int for flask, and do final checks for borders
-        x = max(0, int(current_region.x))
-        y = max(0, int(current_region.y))
-        w = min(width - x - 1, int(current_region.w))
-        h = min(height - y - 1, int(current_region.h))
-
-        facial_area = {
-            "x": x,
-            "y": y,
-            "w": w,
-            "h": h,
-            "left_eye": current_region.left_eye,
-            "right_eye": current_region.right_eye,
-        }
-
-        # optional nose, mouth_left and mouth_right fields are coming just for retinaface
-        if current_region.nose is not None:
-            facial_area["nose"] = current_region.nose
-        if current_region.mouth_left is not None:
-            facial_area["mouth_left"] = current_region.mouth_left
-        if current_region.mouth_right is not None:
-            facial_area["mouth_right"] = current_region.mouth_right
-
-        resp_obj = {
-            "face": current_img,
-            "facial_area": facial_area,
-            "confidence": round(float(current_region.confidence or 0), 2),
-        }
-
-        if anti_spoofing is True:
-            antispoof_model = modeling.build_model(task="spoofing", model_name="Fasnet")
-            is_real, antispoof_score = antispoof_model.analyze(img=img, facial_area=(x, y, w, h))
-            resp_obj["is_real"] = is_real
-            resp_obj["antispoof_score"] = antispoof_score
-
-        resp_objs.append(resp_obj)
-
-    if len(resp_objs) == 0 and enforce_detection == True:
-        raise ValueError(
-            f"Exception while extracting faces from {img_name}."
-            "Consider to set enforce_detection arg to False."
-        )
-
-    return resp_objs
+    batched_input = (
+        (isinstance(img_path, np.ndarray) and img_path.ndim == 4)
+        or isinstance(img_path, list)
+    )
+    if not batched_input:
+        imgs_path = [img_path]
+    elif isinstance(img_path, np.ndarray):
+        imgs_path = [img_path[i] for i in range(img_path.shape[0])]
+    else:
+        imgs_path = img_path
+
+    all_images = []
+    img_names = []
+
+    for single_img_path in imgs_path:
+        # img might be path, base64 or numpy array. Convert it to numpy whatever it is.
+        img, img_name = image_utils.load_image(single_img_path)
+
+        if img is None:
+            raise ValueError(f"Exception while loading {img_name}")
+
+        all_images.append(img)
+        img_names.append(img_name)
+
+    # Run detect_faces for all images at once
+    all_face_objs = detect_faces(
+        detector_backend=detector_backend,
+        img=all_images,
+        align=align,
+        expand_percentage=expand_percentage,
+        max_faces=max_faces,
+    )
+
+    all_resp_objs = []
+    for img, img_name, face_objs in zip(all_images, img_names, all_face_objs):
+        height, width, _ = img.shape
+
+        if len(face_objs) == 0 and enforce_detection is True:
+            if img_name is not None:
+                raise ValueError(
+                    f"Face could not be detected in {img_name}."
+                    "Please confirm that the picture is a face photo "
+                    "or consider to set enforce_detection param to False."
+                )
+            else:
+                raise ValueError(
+                    "Face could not be detected. Please confirm that the picture is a face photo "
+                    "or consider to set enforce_detection param to False."
+                )
+
+        if len(face_objs) == 0 and enforce_detection is False:
+            base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
+            face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
+
+        img_resp_objs = []
+        for face_obj in face_objs:
+            current_img = face_obj.img
+            current_region = face_obj.facial_area
+
+            if current_img.shape[0] == 0 or current_img.shape[1] == 0:
+                continue
+
+            if grayscale is True:
+                logger.warn("Parameter grayscale is deprecated. Use color_face instead.")
+                current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
+            else:
+                if color_face == "rgb":
+                    current_img = current_img[:, :, ::-1]
+                elif color_face == "bgr":
+                    pass  # image is in BGR
+                elif color_face == "gray":
+                    current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
+                else:
+                    raise ValueError(
+                        f"The color_face can be rgb, bgr or gray, but it is {color_face}."
+                    )
+
+            if normalize_face:
+                current_img = current_img / 255  # normalize input in [0, 1]
+
+            # cast to int for flask, and do final checks for borders
+            x = max(0, int(current_region.x))
+            y = max(0, int(current_region.y))
+            w = min(width - x - 1, int(current_region.w))
+            h = min(height - y - 1, int(current_region.h))
+
+            facial_area = {
+                "x": x,
+                "y": y,
+                "w": w,
+                "h": h,
+                "left_eye": current_region.left_eye,
+                "right_eye": current_region.right_eye,
+            }
+
+            # optional nose, mouth_left and mouth_right fields are coming just for retinaface
+            if current_region.nose is not None:
+                facial_area["nose"] = current_region.nose
+            if current_region.mouth_left is not None:
+                facial_area["mouth_left"] = current_region.mouth_left
+            if current_region.mouth_right is not None:
+                facial_area["mouth_right"] = current_region.mouth_right
+
+            resp_obj = {
+                "face": current_img,
+                "facial_area": facial_area,
+                "confidence": round(float(current_region.confidence or 0), 2),
+            }
+
+            if anti_spoofing is True:
+                antispoof_model = modeling.build_model(task="spoofing", model_name="Fasnet")
+                is_real, antispoof_score = antispoof_model.analyze(
+                    img=img, facial_area=(x, y, w, h)
+                )
+                resp_obj["is_real"] = is_real
+                resp_obj["antispoof_score"] = antispoof_score

+            img_resp_objs.append(resp_obj)
+
+        all_resp_objs.append(img_resp_objs)
+
+    if not batched_input:
+        return all_resp_objs[0]
+    return all_resp_objs
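Per the `batched_input` check above, a 4-D numpy array stacked along the first axis triggers batch mode just like a list does. A hedged sketch of that input form (dummy frames, not real data):

```python
import numpy as np

from deepface import DeepFace

# A 4-D array of shape (N, H, W, C) counts as a batch of N BGR frames
imgs = np.zeros((2, 224, 224, 3), dtype=np.uint8)  # two dummy frames
results = DeepFace.extract_faces(img_path=imgs, enforce_detection=False)
assert len(results) == 2  # one inner list of face dicts per frame
```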
 def detect_faces(
     detector_backend: str,
-    img: np.ndarray,
+    img: Union[np.ndarray, List[np.ndarray]],
     align: bool = True,
     expand_percentage: int = 0,
     max_faces: Optional[int] = None,
-) -> List[DetectedFace]:
+) -> Union[List[List[DetectedFace]], List[DetectedFace]]:
     """
-    Detect face(s) from a given image
+    Detect face(s) from a given image or list of images

     Args:
         detector_backend (str): detector name

-        img (np.ndarray): pre-loaded image
+        img (np.ndarray or List[np.ndarray]): pre-loaded image or list of images

         align (bool): enable or disable alignment after detection

         expand_percentage (int): expand detected facial area with a percentage (default is 0).

     Returns:
-        results (List[DetectedFace]): A list of DetectedFace objects
+        results (Union[List[List[DetectedFace]], List[DetectedFace]]):
+            A list of lists of DetectedFace objects or a list of DetectedFace objects
        where each object contains:

         - img (np.ndarray): The detected face as a NumPy array.
@@ -219,53 +244,113 @@ def detect_faces(
         - confidence (float): The confidence score associated with the detected face.
     """
-    height, width, _ = img.shape
+    batched_input = (
+        (isinstance(img, np.ndarray) and img.ndim == 4)
+        or isinstance(img, list)
+    )
+    if not batched_input:
+        imgs = [img]
+    elif isinstance(img, np.ndarray):
+        imgs = [img[i] for i in range(img.shape[0])]
+    else:
+        imgs = img
+
+    if detector_backend == "skip":
+        all_face_objs = [
+            [
+                DetectedFace(
+                    img=single_img,
+                    facial_area=FacialAreaRegion(
+                        x=0, y=0, w=single_img.shape[1], h=single_img.shape[0]
+                    ),
+                    confidence=0,
+                )
+            ]
+            for single_img in imgs
+        ]
+        if not batched_input:
+            all_face_objs = all_face_objs[0]
+        return all_face_objs

     face_detector: Detector = modeling.build_model(
         task="face_detector", model_name=detector_backend
     )

-    # validate expand percentage score
-    if expand_percentage < 0:
-        logger.warn(
-            f"Expand percentage cannot be negative but you set it to {expand_percentage}."
-            "Overwritten it to 0."
-        )
-        expand_percentage = 0
-
-    # If faces are close to the upper boundary, alignment move them outside
-    # Add a black border around an image to avoid this.
-    height_border = int(0.5 * height)
-    width_border = int(0.5 * width)
-    if align is True:
-        img = cv2.copyMakeBorder(
-            img,
-            height_border,
-            height_border,
-            width_border,
-            width_border,
-            cv2.BORDER_CONSTANT,
-            value=[0, 0, 0],  # Color of the border (black)
-        )
-
-    # find facial areas of given image
-    facial_areas = face_detector.detect_faces(img)
-
-    if max_faces is not None and max_faces < len(facial_areas):
-        facial_areas = nlargest(
-            max_faces, facial_areas, key=lambda facial_area: facial_area.w * facial_area.h
-        )
-
-    return [
-        extract_face(
-            facial_area=facial_area,
-            img=img,
-            align=align,
-            expand_percentage=expand_percentage,
-            width_border=width_border,
-            height_border=height_border,
-        )
-        for facial_area in facial_areas
-    ]
+    preprocessed_images = []
+    width_borders = []
+    height_borders = []
+    for single_img in imgs:
+        height, width, _ = single_img.shape
+
+        # validate expand percentage score
+        if expand_percentage < 0:
+            logger.warn(
+                f"Expand percentage cannot be negative but you set it to {expand_percentage}."
+                "Overwritten it to 0."
+            )
+            expand_percentage = 0
+
+        # If faces are close to the upper boundary, alignment moves them outside.
+        # Add a black border around the image to avoid this.
+        height_border = int(0.5 * height)
+        width_border = int(0.5 * width)
+        if align is True:
+            single_img = cv2.copyMakeBorder(
+                single_img,
+                height_border,
+                height_border,
+                width_border,
+                width_border,
+                cv2.BORDER_CONSTANT,
+                value=[0, 0, 0],  # Color of the border (black)
+            )
+
+        preprocessed_images.append(single_img)
+        width_borders.append(width_border)
+        height_borders.append(height_border)
+
+    # Detect faces in all preprocessed images
+    all_facial_areas = face_detector.detect_faces(preprocessed_images)
+
+    all_detected_faces = []
+    for (
+        single_img,
+        facial_areas,
+        width_border,
+        height_border,
+    ) in zip(
+        preprocessed_images,
+        all_facial_areas,
+        width_borders,
+        height_borders,
+    ):
+        if not isinstance(facial_areas, list):
+            facial_areas = [facial_areas]
+
+        if max_faces is not None and max_faces < len(facial_areas):
+            facial_areas = nlargest(
+                max_faces, facial_areas, key=lambda facial_area: facial_area.w * facial_area.h
+            )
+
+        detected_faces = [
+            extract_face(
+                facial_area=facial_area,
+                img=single_img,
+                align=align,
+                expand_percentage=expand_percentage,
+                width_border=width_border,
+                height_border=height_border,
+            )
+            for facial_area in facial_areas
+            if isinstance(facial_area, FacialAreaRegion)
+        ]
+        all_detected_faces.append(detected_faces)
+
+    if not batched_input:
+        return all_detected_faces[0]
+    return all_detected_faces

 def extract_face(
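The border logic above pads every image by half its height and width before detection so that alignment cannot rotate near-edge faces out of frame. A minimal sketch of the effect, using only the calls shown in the diff (toy dimensions):

```python
import cv2
import numpy as np

img = np.zeros((100, 200, 3), dtype=np.uint8)  # toy BGR image
h, w = img.shape[:2]

# same padding as detect_faces(): half the image size on every side, in black
padded = cv2.copyMakeBorder(
    img, h // 2, h // 2, w // 2, w // 2, cv2.BORDER_CONSTANT, value=[0, 0, 0]
)
print(img.shape, padded.shape)  # (100, 200, 3) (200, 400, 3)
```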

View File

@@ -23,6 +23,10 @@ def test_different_detectors():
     for detector in detectors:
         img_objs = DeepFace.extract_faces(img_path=img_path, detector_backend=detector)

+        # Check return type for non-batch input
+        assert isinstance(img_objs, list) and all(isinstance(obj, dict) for obj in img_objs)
+
         for img_obj in img_objs:
             assert "face" in img_obj.keys()
             assert "facial_area" in img_obj.keys()
@@ -79,6 +83,169 @@ def test_different_detectors():
         logger.info(f"✅ extract_faces for {detector} backend test is done")

+@pytest.mark.parametrize("detector_backend", [
+    # "opencv",
+    "ssd",
+    "mtcnn",
+    "retinaface",
+    "yunet",
+    "centerface",
+    # optional
+    # "yolov11s",
+    # "mediapipe",
+    # "dlib",
+])
+def test_batch_extract_faces(detector_backend):
+    # Relative tolerance for comparing floating-point values
+    rtol = 0.03
+    img_paths = [
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img11.jpg",
+        "dataset/couple.jpg",
+    ]
+    expected_num_faces = [1, 1, 1, 2]
+
+    # Extract faces one by one
+    imgs_objs_individual = [
+        DeepFace.extract_faces(
+            img_path=img_path,
+            detector_backend=detector_backend,
+            align=True,
+        ) for img_path in img_paths
+    ]
+
+    # Check that individual extraction returns a list of face dicts
+    for img_objs_individual in imgs_objs_individual:
+        assert isinstance(img_objs_individual, list)
+        assert all(isinstance(face, dict) for face in img_objs_individual)
+
+    # Check that the individual extraction results match the expected number of faces
+    for img_objs_individual, expected_faces in zip(imgs_objs_individual, expected_num_faces):
+        assert len(img_objs_individual) == expected_faces
+
+    # Extract faces in batch
+    imgs_objs_batch = DeepFace.extract_faces(
+        img_path=img_paths,
+        detector_backend=detector_backend,
+        align=True,
+    )
+
+    # Check that the batch extraction returned the expected number of face lists
+    assert len(imgs_objs_batch) == len(img_paths)
+
+    # Check that each face list has the expected number of faces
+    for i, expected_faces in enumerate(expected_num_faces):
+        assert len(imgs_objs_batch[i]) == expected_faces
+
+    # Check that the individual extraction results match the batch extraction results
+    for img_objs_individual, img_objs_batch in zip(imgs_objs_individual, imgs_objs_batch):
+        assert len(img_objs_batch) == len(img_objs_individual), (
+            "Batch and individual extraction results should have the same number of detected faces"
+        )
+        for img_obj_individual, img_obj_batch in zip(img_objs_individual, img_objs_batch):
+            for key in img_obj_individual["facial_area"]:
+                if isinstance(img_obj_individual["facial_area"][key], tuple):
+                    for ind_val, batch_val in zip(
+                        img_obj_individual["facial_area"][key],
+                        img_obj_batch["facial_area"][key],
+                    ):
+                        # Ensure the difference between individual and batch values
+                        # is within rtol of the individual value
+                        assert abs(ind_val - batch_val) <= rtol * ind_val
+                elif isinstance(img_obj_individual["facial_area"][key], (int, float)):
+                    # Ensure the difference between individual and batch values
+                    # is within rtol of the individual value
+                    assert abs(
+                        img_obj_individual["facial_area"][key] -
+                        img_obj_batch["facial_area"][key]
+                    ) <= rtol * img_obj_individual["facial_area"][key]
+
+            # Ensure the confidence difference is within rtol of the individual confidence
+            assert abs(
+                img_obj_individual["confidence"] -
+                img_obj_batch["confidence"]
+            ) <= rtol * img_obj_individual["confidence"]
+
+
+@pytest.mark.parametrize("detector_backend", [
+    "opencv",
+    "ssd",
+    "mtcnn",
+    "retinaface",
+    "yunet",
+    # "centerface",
+    # optional
+    # "yolov11s",
+    # "mediapipe",
+    # "dlib",
+])
+def test_batch_extract_faces_with_nparray(detector_backend):
+    img_paths = [
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img11.jpg",
+        "dataset/couple.jpg",
+    ]
+    imgs = [
+        cv2.resize(image_utils.load_image(img_path)[0], (1920, 1080))
+        for img_path in img_paths
+    ]
+    expected_num_faces = [1, 1, 1, 2]
+
+    # stack pre-loaded images into a single 4-D numpy array
+    imgs_batch = np.stack(imgs, axis=0)
+
+    # extract faces from the batch given as a 4-D numpy array
+    imgs_objs_batch = DeepFace.extract_faces(
+        img_path=imgs_batch,
+        detector_backend=detector_backend,
+        align=True,
+        enforce_detection=False,
+    )
+
+    # Check return type for batch input
+    assert (
+        isinstance(imgs_objs_batch, list) and
+        all(
+            isinstance(obj, list) and
+            all(isinstance(face, dict) for face in obj)
+            for obj in imgs_objs_batch
+        )
+    )
+
+    # Check that the batch extraction returned the expected number of face lists
+    assert len(imgs_objs_batch) == len(img_paths)
+    for img_objs_batch, img_expected_num_faces in zip(imgs_objs_batch, expected_num_faces):
+        assert len(img_objs_batch) == img_expected_num_faces
+
+    # extract faces from the same batch given as a list of numpy arrays
+    imgs_objs_batch_paths = DeepFace.extract_faces(
+        img_path=imgs,
+        detector_backend=detector_backend,
+        align=True,
+        enforce_detection=False,
+    )
+
+    # compare results of the two batch input forms
+    for img_objs_batch, img_objs_batch_paths in zip(imgs_objs_batch, imgs_objs_batch_paths):
+        assert len(img_objs_batch) == len(img_objs_batch_paths), (
+            "Batch and individual extraction results should have the same number of detected faces"
+        )
+
+
+def test_batch_extract_faces_single_image():
+    img_path = "dataset/couple.jpg"
+    imgs_objs_batch = DeepFace.extract_faces(
+        img_path=[img_path],
+        align=True,
+    )
+    # a one-element batch must still return a list of lists
+    assert len(imgs_objs_batch) == 1 and isinstance(imgs_objs_batch[0], list)
+    assert all(isinstance(obj, dict) for obj in imgs_objs_batch[0])
+
+
 def test_backends_for_enforced_detection_with_non_facial_inputs():
     black_img = np.zeros([224, 224, 3])
     for detector in detectors: