Merge 3111a2895a0e2561403ea193dd2f75f069160d57 into df1b6ab6fe431f126afe7ff3baa7fe394a967f0a

This commit is contained in:
galthran-wq 2025-05-17 18:30:50 -07:00 committed by GitHub
commit ab84683bbe
14 changed files with 801 additions and 326 deletions

View File

@ -521,7 +521,7 @@ def stream(
def extract_faces(
img_path: Union[str, np.ndarray, IO[bytes]],
img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
@ -530,14 +530,14 @@ def extract_faces(
color_face: str = "rgb",
normalize_face: bool = True,
anti_spoofing: bool = False,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Extract faces from a given image
Extract faces from a given image or sequence of images.
Args:
img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), a file object that supports at least `.read` and is
opened in binary mode, or base64 encoded images.
img_path (Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]]):
Path(s) to the image(s). Accepts a string path, a numpy array (BGR), a file object
that supports at least `.read` and is opened in binary mode, or base64 encoded images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
@ -562,7 +562,8 @@ def extract_faces(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]):
A list or a list of lists of dictionaries, where each dictionary contains:
- "face" (np.ndarray): The detected face as a NumPy array.

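A hedged usage sketch of the widened entry point (the image paths below are placeholders): a single input keeps the familiar List[Dict] return, while a sequence yields one inner list per image, in input order.

from deepface import DeepFace

# single image -> List[Dict[str, Any]], unchanged behavior
faces = DeepFace.extract_faces(img_path="img1.jpg")

# sequence of images -> List[List[Dict[str, Any]]], one inner list per input
batch = DeepFace.extract_faces(img_path=["img1.jpg", "img2.jpg"])
for image_faces in batch:
    for face in image_faces:
        print(face["facial_area"], face["confidence"])
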
View File

@ -1,4 +1,4 @@
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
@ -9,15 +9,20 @@ import numpy as np
# pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
class Detector(ABC):
@abstractmethod
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List["FacialAreaRegion"], List[List["FacialAreaRegion"]]]:
"""
Interface for detect and align face
Interface to detect and align faces in a single image or a batch of images
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
Pre-loaded image as numpy array or a list of those
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
results (Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]):
A list or a list of lists of FacialAreaRegion objects
where each object contains:
- facial_area (FacialAreaRegion): The facial area region represented
@ -28,6 +33,7 @@ class Detector(ABC):
pass
# pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
@dataclass
class FacialAreaRegion:
"""

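A minimal sketch of how a concrete backend can satisfy the widened interface, assuming a hypothetical per-image helper _detect_single; the detector clients below follow this same wrap-and-unwrap pattern.

from typing import List, Union

import numpy as np

from deepface.models.Detector import Detector, FacialAreaRegion

class SketchDetector(Detector):  # hypothetical subclass for illustration
    def detect_faces(
        self,
        img: Union[np.ndarray, List[np.ndarray]],
    ) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
        is_batched_input = isinstance(img, list)
        if not is_batched_input:
            img = [img]
        # _detect_single is a hypothetical helper returning List[FacialAreaRegion]
        results = [self._detect_single(single_img) for single_img in img]
        return results if is_batched_input else results[0]
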
View File

@ -1,6 +1,6 @@
# built-in dependencies
import os
from typing import List
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -34,12 +34,35 @@ class CenterFaceClient(Detector):
return CenterFace(weight_path=weights_path)
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]],
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with CenterFace
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
results = [self._process_single_image(single_img) for single_img in img]
if not is_batched_input:
return results[0]
return results
def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
single_img (np.ndarray): pre-loaded image as numpy array
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
@ -53,7 +76,7 @@ class CenterFaceClient(Detector):
# img, img.shape[0], img.shape[1], threshold=threshold
# )
detections, landmarks = self.build_model().forward(
img, img.shape[0], img.shape[1], threshold=threshold
single_img, single_img.shape[0], single_img.shape[1], threshold=threshold
)
for i, detection in enumerate(detections):

View File

@ -1,5 +1,5 @@
# built-in dependencies
from typing import List
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -47,10 +47,33 @@ class DlibClient(Detector):
detector["sp"] = sp
return detector
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]],
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with dlib
Args:
img (Union[np.ndarray, List[np.ndarray]]):
pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
results = [self._process_single_image(single_img) for single_img in img]
if not is_batched_input:
return results[0]
return results
def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
img (np.ndarray): pre-loaded image as numpy array

View File

@ -17,10 +17,33 @@ class FastMtCnnClient(Detector):
def __init__(self):
self.model = self.build_model()
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]],
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with mtcnn
Args:
img (Union[np.ndarray, List[np.ndarray]]):
pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
results = [self._process_single_image(single_img) for single_img in img]
if not is_batched_input:
return results[0]
return results
def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
img (np.ndarray): pre-loaded image as numpy array

View File

@ -1,6 +1,6 @@
# built-in dependencies
import os
from typing import Any, List
from typing import Any, List, Union
# 3rd party dependencies
import numpy as np
@ -43,10 +43,33 @@ class MediaPipeClient(Detector):
)
return face_detection
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]],
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with mediapipe
Args:
img (Union[np.ndarray, List[np.ndarray]]):
pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
results = [self._process_single_image(single_img) for single_img in img]
if not is_batched_input:
return results[0]
return results
def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
img (np.ndarray): pre-loaded image as numpy array

View File

@ -1,5 +1,6 @@
# built-in dependencies
from typing import List
import logging
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -8,6 +9,8 @@ from mtcnn import MTCNN
# project dependencies
from deepface.models.Detector import Detector, FacialAreaRegion
logger = logging.getLogger(__name__)
# pylint: disable=too-few-public-methods
class MtCnnClient(Detector):
"""
@ -16,28 +19,43 @@ class MtCnnClient(Detector):
def __init__(self):
self.model = MTCNN()
self.supports_batch_detection = self._supports_batch_detection()
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with mtcnn
Detect and align faces with mtcnn for a single image or a list of images
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
pre-loaded image as numpy array or a list of those
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list of FacialAreaRegion objects for a single image
or a list of lists of FacialAreaRegion objects for each image
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
resp = []
# mtcnn expects RGB but OpenCV read BGR
# img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_rgb = img[:, :, ::-1]
img_rgb = [single_img[:, :, ::-1] for single_img in img]
if self.supports_batch_detection:
detections = self.model.detect_faces(img_rgb)
else:
detections = [self.model.detect_faces(single_img) for single_img in img_rgb]
if detections is not None and len(detections) > 0:
for current_detection in detections:
for image_detections in detections:
image_resp = []
if image_detections is not None and len(image_detections) > 0:
for current_detection in image_detections:
x, y, w, h = current_detection["box"]
confidence = current_detection["confidence"]
# mtcnn detector assigns left eye with respect to the observer
@ -55,6 +73,17 @@ class MtCnnClient(Detector):
confidence=confidence,
)
resp.append(facial_area)
image_resp.append(facial_area)
resp.append(image_resp)
if not is_batched_input:
return resp[0]
return resp
def _supports_batch_detection(self) -> bool:
logger.warning(
"Batch detection is disabled for mtcnn by default "
"since the results are not consistent with single image detection. "
)
return False

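The channel-reversal slice used above is equivalent to an explicit BGR-to-RGB conversion; a quick sketch:

import cv2
import numpy as np

bgr = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
# reversing the last axis swaps B and R, matching cv2.cvtColor
assert np.array_equal(bgr[:, :, ::-1], cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
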
View File

@ -1,6 +1,7 @@
# built-in dependencies
import os
from typing import Any, List
from typing import Any, List, Union
import logging
# 3rd party dependencies
import cv2
@ -9,6 +10,7 @@ import numpy as np
#project dependencies
from deepface.models.Detector import Detector, FacialAreaRegion
logger = logging.getLogger(__name__)
class OpenCvClient(Detector):
"""
@ -17,6 +19,7 @@ class OpenCvClient(Detector):
def __init__(self):
self.model = self.build_model()
self.supports_batch_detection = self._supports_batch_detection()
def build_model(self):
"""
@ -29,38 +32,53 @@ class OpenCvClient(Detector):
detector["eye_detector"] = self.__build_cascade("haarcascade_eye")
return detector
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def _supports_batch_detection(self) -> bool:
logger.warning(
"Batch detection is disabled for opencv by default "
"since the results are not consistent with single image detection. "
)
return False
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with opencv
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
Pre-loaded image as numpy array or a list of those
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
if isinstance(img, np.ndarray):
imgs = [img]
elif self.supports_batch_detection:
imgs = img
else:
return [self.detect_faces(single_img) for single_img in img]
batch_results = []
for single_img in imgs:
resp = []
detected_face = None
faces = []
try:
# faces = detector["face_detector"].detectMultiScale(img, 1.3, 5)
# note that, by design, opencv's haarcascade scores are >0 but not capped at 1
faces, _, scores = self.model["face_detector"].detectMultiScale3(
img, 1.1, 10, outputRejectLevels=True
single_img, 1.1, 10, outputRejectLevels=True
)
except:
pass
if len(faces) > 0:
for (x, y, w, h), confidence in zip(faces, scores):
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = single_img[int(y):int(y + h), int(x):int(x + w)]
left_eye, right_eye = self.find_eyes(img=detected_face)
# eyes found in the detected face instead image itself
# detected face's coordinates should be added
if left_eye is not None:
left_eye = (int(x + left_eye[0]), int(y + left_eye[1]))
if right_eye is not None:
@ -77,7 +95,9 @@ class OpenCvClient(Detector):
)
resp.append(facial_area)
return resp
batch_results.append(resp)
# preserve the batched shape for list inputs, even with a single image
return batch_results if isinstance(img, list) else batch_results[0]
def find_eyes(self, img: np.ndarray) -> tuple:
"""

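For reference, detectMultiScale3 with outputRejectLevels=True returns boxes together with reject levels and level weights, and the weights serve as the uncapped confidence scores mentioned in the comment. A sketch with a placeholder frame:

import cv2
import numpy as np

cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
frame = np.zeros((240, 240, 3), dtype=np.uint8)  # placeholder image
faces, _, scores = cascade.detectMultiScale3(
    frame, 1.1, 10, outputRejectLevels=True
)
for (x, y, w, h), confidence in zip(faces, scores):
    print(x, y, w, h, float(confidence))  # > 0, not capped at 1
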
View File

@ -1,5 +1,5 @@
# built-in dependencies
from typing import List
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -13,23 +13,34 @@ class RetinaFaceClient(Detector):
def __init__(self):
self.model = rf.build_model()
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with retinaface
Detect and align faces with retinaface in a single image or a batch of images
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
Pre-loaded image as numpy array or a list of those
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
imgs = [img]
else:
imgs = img
batch_results = []
for single_img in imgs:
resp = []
obj = rf.detect_faces(single_img, model=self.model, threshold=0.9)
obj = rf.detect_faces(img, model=self.model, threshold=0.9)
if not isinstance(obj, dict):
return resp
if isinstance(obj, dict):
for face_idx in obj.keys():
identity = obj[face_idx]
detection = identity["facial_area"]
@ -39,16 +50,12 @@ class RetinaFaceClient(Detector):
x = detection[0]
w = detection[2] - x
# retinaface sets left and right eyes with respect to the person
left_eye = identity["landmarks"]["left_eye"]
right_eye = identity["landmarks"]["right_eye"]
left_eye = tuple(int(i) for i in identity["landmarks"]["left_eye"])
right_eye = tuple(int(i) for i in identity["landmarks"]["right_eye"])
nose = identity["landmarks"].get("nose")
mouth_right = identity["landmarks"].get("mouth_right")
mouth_left = identity["landmarks"].get("mouth_left")
# eyes are list of float, need to cast them tuple of int
left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye)
if nose is not None:
nose = tuple(int(i) for i in nose)
if mouth_right is not None:
@ -73,4 +80,8 @@ class RetinaFaceClient(Detector):
resp.append(facial_area)
return resp
batch_results.append(resp)
if not is_batched_input:
return batch_results[0]
return batch_results

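retinaface reports landmarks as float pairs, so the tightened code above casts them to int tuples at lookup time; a sketch over an assumed identity entry:

identity = {  # assumed shape of one rf.detect_faces result entry
    "facial_area": [94, 58, 151, 133],
    "landmarks": {"left_eye": [110.4, 87.2], "right_eye": [138.9, 86.1]},
}
left_eye = tuple(int(i) for i in identity["landmarks"]["left_eye"])
right_eye = tuple(int(i) for i in identity["landmarks"]["right_eye"])
nose = identity["landmarks"].get("nose")  # optional keys may be absent
assert left_eye == (110, 87) and right_eye == (138, 86) and nose is None
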
View File

@ -1,5 +1,5 @@
# built-in dependencies
from typing import List
from typing import List, Union
from enum import IntEnum
# 3rd party dependencies
@ -54,28 +54,48 @@ class SsdClient(Detector):
return {"face_detector": face_detector, "opencv_module": OpenCv.OpenCvClient()}
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with ssd
Detect and align faces with ssd in a single image or a batch of images
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
Pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
results = [self._process_single_image(single_img) for single_img in img]
if not is_batched_input:
return results[0]
return results
def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
single_img (np.ndarray): Pre-loaded image as numpy array
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
# Because cv2.dnn.blobFromImage expects CV_8U (8-bit unsigned integer) values
if img.dtype != np.uint8:
img = img.astype(np.uint8)
if single_img.dtype != np.uint8:
single_img = single_img.astype(np.uint8)
opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]
target_size = (300, 300)
original_size = img.shape
current_img = cv2.resize(img, target_size)
original_size = single_img.shape
current_img = cv2.resize(single_img, target_size)
aspect_ratio_x = original_size[1] / target_size[1]
aspect_ratio_y = original_size[0] / target_size[0]
@ -112,7 +132,7 @@ class SsdClient(Detector):
for face in faces:
confidence = float(face[ssd_labels.confidence])
x, y, w, h = map(int, face[margins])
detected_face = img[y : y + h, x : x + w]
detected_face = single_img[y : y + h, x : x + w]
left_eye, right_eye = opencv_module.find_eyes(detected_face)
@ -133,4 +153,5 @@ class SsdClient(Detector):
confidence=confidence,
)
resp.append(facial_area)
return resp

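Detections from the 300x300 network input are mapped back to the original frame with per-axis aspect ratios; a sketch of that rescaling with assumed numbers:

original_size = (1080, 1920, 3)                     # (height, width, channels)
target_size = (300, 300)
aspect_ratio_x = original_size[1] / target_size[1]  # 6.4, width scale
aspect_ratio_y = original_size[0] / target_size[0]  # 3.6, height scale

x, y, w, h = 100, 120, 40, 55                       # box in 300x300 coordinates
box_in_original = (
    int(x * aspect_ratio_x),  # 640
    int(y * aspect_ratio_y),  # 432
    int(w * aspect_ratio_x),  # 256
    int(h * aspect_ratio_y),  # 198
)
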
View File

@ -1,6 +1,6 @@
# built-in dependencies
import os
from typing import List, Any
from typing import List, Any, Union
from enum import Enum
# 3rd party dependencies
@ -62,25 +62,38 @@ class YoloDetectorClient(Detector):
# Return face_detector
return YOLO(weight_file)
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]:
"""
Detect and align face with yolo
Detect and align faces in a single image or a batch of images with yolo
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]):
Pre-loaded image as numpy array or a list of those
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
results (Union[List[List[FacialAreaRegion]], List[FacialAreaRegion]]):
A list of lists of FacialAreaRegion objects
for each image or a list of FacialAreaRegion objects
"""
resp = []
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
# Detect faces
results = self.model.predict(
all_results = []
# Detect faces for all images
results_list = self.model.predict(
img,
verbose=False,
show=False,
conf=float(os.getenv("YOLO_MIN_DETECTION_CONFIDENCE", "0.25")),
)[0]
)
# Iterate over each image's results
for results in results_list:
resp = []
# For each face, extract the bounding box, the landmarks and confidence
for result in results:
@ -104,9 +117,17 @@ class YoloDetectorClient(Detector):
left_eye = result.keypoints.xy[0][1].tolist()
# eyes are list of float, need to cast them tuple of int
left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye)
# Ensure eyes are tuples of exactly two integers or None
left_eye = (
tuple(map(int, left_eye[:2]))
if left_eye and len(left_eye) == 2
else None
)
right_eye = (
tuple(map(int, right_eye[:2]))
if right_eye and len(right_eye) == 2
else None
)
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
facial_area = FacialAreaRegion(
x=x,
@ -119,7 +140,11 @@ class YoloDetectorClient(Detector):
)
resp.append(facial_area)
return resp
all_results.append(resp)
if not is_batched_input:
return all_results[0]
return all_results
class YoloDetectorClientV8n(YoloDetectorClient):

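ultralytics' predict accepts a list of frames and returns one Results object per frame, which is what the new outer loop iterates over; a sketch assuming yolov8n weights as a placeholder model:

import numpy as np
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights for illustration
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]
results_list = model.predict(frames, verbose=False, conf=0.25)
assert len(results_list) == len(frames)  # one Results object per frame
for results in results_list:
    for result in results:  # one entry per detection in that frame
        x, y, w, h = result.boxes.xywh[0].tolist()
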
View File

@ -1,6 +1,6 @@
# built-in dependencies
import os
from typing import Any, List
from typing import Any, List, Union
# 3rd party dependencies
import cv2
@ -57,10 +57,28 @@ class YuNetClient(Detector):
) from err
return face_detector
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
def detect_faces(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with yunet
Args:
img (Union[np.ndarray, List[np.ndarray]]):
pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]):
A list or a list of lists of FacialAreaRegion objects
"""
is_batched_input = isinstance(img, list)
if not is_batched_input:
img = [img]
results = [self._process_single_image(single_img) for single_img in img]
if not is_batched_input:
return results[0]
return results
def _process_single_image(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
img (np.ndarray): pre-loaded image as numpy array

View File

@ -1,5 +1,5 @@
# built-in dependencies
from typing import Any, Dict, IO, List, Tuple, Union, Optional
from typing import Any, Dict, IO, List, Tuple, Union, Optional, Sequence
# 3rd part dependencies
from heapq import nlargest
@ -19,7 +19,7 @@ logger = Logger()
def extract_faces(
img_path: Union[str, np.ndarray, IO[bytes]],
img_path: Union[Sequence[Union[str, np.ndarray, IO[bytes]]], str, np.ndarray, IO[bytes]],
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
@ -29,13 +29,14 @@ def extract_faces(
normalize_face: bool = True,
anti_spoofing: bool = False,
max_faces: Optional[int] = None,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Extract faces from a given image
Extract faces from a given image or list of images
Args:
img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), a file object that supports at least `.read` and is
img_path (List[str or np.ndarray or IO[bytes]] or str or np.ndarray or IO[bytes]):
Path(s) to the image(s) as a string,
numpy array (BGR), a file object that supports at least `.read` and is
opened in binary mode, or base64 encoded images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
@ -61,7 +62,8 @@ def extract_faces(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
results (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]):
A list or list of lists of dictionaries, where each dictionary contains:
- "face" (np.ndarray): The detected face as a NumPy array in RGB format.
@ -80,30 +82,46 @@ def extract_faces(
just available in the result only if anti_spoofing is set to True in input arguments.
"""
resp_objs = []
batched_input = (
(
isinstance(img_path, np.ndarray) and
img_path.ndim == 4
) or isinstance(img_path, list)
)
if not batched_input:
imgs_path = [img_path]
elif isinstance(img_path, np.ndarray):
imgs_path = [img_path[i] for i in range(img_path.shape[0])]
else:
imgs_path = img_path
all_images = []
img_names = []
for single_img_path in imgs_path:
# img might be path, base64 or numpy array. Convert it to numpy whatever it is.
img, img_name = image_utils.load_image(img_path)
img, img_name = image_utils.load_image(single_img_path)
if img is None:
raise ValueError(f"Exception while loading {img_name}")
height, width, _ = img.shape
all_images.append(img)
img_names.append(img_name)
base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
if detector_backend == "skip":
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
else:
face_objs = detect_faces(
# Run detect_faces for all images at once
all_face_objs = detect_faces(
detector_backend=detector_backend,
img=img,
img=all_images,
align=align,
expand_percentage=expand_percentage,
max_faces=max_faces,
)
# in case of no face found
all_resp_objs = []
for img, img_name, face_objs in zip(all_images, img_names, all_face_objs):
height, width, _ = img.shape
if len(face_objs) == 0 and enforce_detection is True:
if img_name is not None:
raise ValueError(
@ -118,8 +136,10 @@ def extract_faces(
)
if len(face_objs) == 0 and enforce_detection is False:
base_region = FacialAreaRegion(x=0, y=0, w=width, h=height, confidence=0)
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
img_resp_objs = []
for face_obj in face_objs:
current_img = face_obj.img
current_region = face_obj.facial_area
@ -138,7 +158,10 @@ def extract_faces(
elif color_face == "gray":
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
else:
raise ValueError(f"The color_face can be rgb, bgr or gray, but it is {color_face}.")
raise ValueError(
f"The color_face can be rgb, bgr or gray, "
f"but it is {color_face}."
)
if normalize_face:
current_img = current_img / 255 # normalize input in [0, 1]
@ -174,41 +197,43 @@ def extract_faces(
if anti_spoofing is True:
antispoof_model = modeling.build_model(task="spoofing", model_name="Fasnet")
is_real, antispoof_score = antispoof_model.analyze(img=img, facial_area=(x, y, w, h))
is_real, antispoof_score = antispoof_model.analyze(
img=img,
facial_area=(x, y, w, h)
)
resp_obj["is_real"] = is_real
resp_obj["antispoof_score"] = antispoof_score
resp_objs.append(resp_obj)
img_resp_objs.append(resp_obj)
if len(resp_objs) == 0 and enforce_detection == True:
raise ValueError(
f"Exception while extracting faces from {img_name}."
"Consider to set enforce_detection arg to False."
)
all_resp_objs.append(img_resp_objs)
return resp_objs
if not batched_input:
return all_resp_objs[0]
return all_resp_objs
def detect_faces(
detector_backend: str,
img: np.ndarray,
img: Union[np.ndarray, List[np.ndarray]],
align: bool = True,
expand_percentage: int = 0,
max_faces: Optional[int] = None,
) -> List[DetectedFace]:
) -> Union[List[List[DetectedFace]], List[DetectedFace]]:
"""
Detect face(s) from a given image
Detect face(s) from a given image or list of images
Args:
detector_backend (str): detector name
img (np.ndarray): pre-loaded image
img (np.ndarray or List[np.ndarray]): pre-loaded image or list of images
align (bool): enable or disable alignment after detection
expand_percentage (int): expand detected facial area with a percentage (default is 0).
Returns:
results (List[DetectedFace]): A list of DetectedFace objects
results (Union[List[List[DetectedFace]], List[DetectedFace]]):
A list of lists of DetectedFace objects or a list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
@ -219,11 +244,46 @@ def detect_faces(
- confidence (float): The confidence score associated with the detected face.
"""
height, width, _ = img.shape
batched_input = (
(
isinstance(img, np.ndarray) and
img.ndim == 4
) or isinstance(img, list)
)
if not batched_input:
imgs = [img]
elif isinstance(img, np.ndarray):
imgs = [img[i] for i in range(img.shape[0])]
else:
imgs = img
if detector_backend == "skip":
all_face_objs = [
[
DetectedFace(
img=single_img,
facial_area=FacialAreaRegion(
x=0, y=0, w=single_img.shape[1], h=single_img.shape[0]
),
confidence=0,
)
]
for single_img in imgs
]
if not batched_input:
all_face_objs = all_face_objs[0]
return all_face_objs
face_detector: Detector = modeling.build_model(
task="face_detector", model_name=detector_backend
)
preprocessed_images = []
width_borders = []
height_borders = []
for single_img in imgs:
height, width, _ = single_img.shape
# validate expand percentage score
if expand_percentage < 0:
logger.warn(
@ -237,8 +297,8 @@ def detect_faces(
height_border = int(0.5 * height)
width_border = int(0.5 * width)
if align is True:
img = cv2.copyMakeBorder(
img,
single_img = cv2.copyMakeBorder(
single_img,
height_border,
height_border,
width_border,
@ -247,26 +307,51 @@ def detect_faces(
value=[0, 0, 0], # Color of the border (black)
)
# find facial areas of given image
facial_areas = face_detector.detect_faces(img)
preprocessed_images.append(single_img)
width_borders.append(width_border)
height_borders.append(height_border)
# Detect faces in all preprocessed images
all_facial_areas = face_detector.detect_faces(preprocessed_images)
all_detected_faces = []
for (
single_img,
facial_areas,
width_border,
height_border
) in zip(
preprocessed_images,
all_facial_areas,
width_borders,
height_borders
):
if not isinstance(facial_areas, list):
facial_areas = [facial_areas]
if max_faces is not None and max_faces < len(facial_areas):
facial_areas = nlargest(
max_faces, facial_areas, key=lambda facial_area: facial_area.w * facial_area.h
)
return [
detected_faces = [
extract_face(
facial_area=facial_area,
img=img,
img=single_img,
align=align,
expand_percentage=expand_percentage,
width_border=width_border,
height_border=height_border,
)
for facial_area in facial_areas
for facial_area in facial_areas if isinstance(facial_area, FacialAreaRegion)
]
all_detected_faces.append(detected_faces)
if not batched_input:
return all_detected_faces[0]
return all_detected_faces
def extract_face(
facial_area: FacialAreaRegion,

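A stacked 4-D ndarray counts as batched input, just like a Python list, and is split back into frames before detection; a minimal sketch of that branch:

import numpy as np

frames = np.zeros((3, 480, 640, 3), dtype=np.uint8)  # three stacked BGR frames

batched_input = (
    isinstance(frames, np.ndarray) and frames.ndim == 4
) or isinstance(frames, list)
if batched_input and isinstance(frames, np.ndarray):
    imgs = [frames[i] for i in range(frames.shape[0])]  # unstack into frames
assert batched_input and len(imgs) == 3
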
View File

@ -23,6 +23,10 @@ def test_different_detectors():
for detector in detectors:
img_objs = DeepFace.extract_faces(img_path=img_path, detector_backend=detector)
# Check return type for non-batch input
assert isinstance(img_objs, list) and all(isinstance(obj, dict) for obj in img_objs)
for img_obj in img_objs:
assert "face" in img_obj.keys()
assert "facial_area" in img_obj.keys()
@ -79,6 +83,169 @@ def test_different_detectors():
logger.info(f"✅ extract_faces for {detector} backend test is done")
@pytest.mark.parametrize("detector_backend", [
# "opencv",
"ssd",
"mtcnn",
"retinaface",
"yunet",
"centerface",
# optional
# "yolov11s",
# "mediapipe",
# "dlib",
])
def test_batch_extract_faces(detector_backend):
# Relative tolerance for comparing floating-point values
rtol = 0.03
img_paths = [
"dataset/img2.jpg",
"dataset/img3.jpg",
"dataset/img11.jpg",
"dataset/couple.jpg"
]
expected_num_faces = [1, 1, 1, 2]
# Extract faces one by one
imgs_objs_individual = [
DeepFace.extract_faces(
img_path=img_path,
detector_backend=detector_backend,
align=True,
) for img_path in img_paths
]
# Check that individual extraction returns a list of faces
for img_objs_individual in imgs_objs_individual:
assert isinstance(img_objs_individual, list)
assert all(isinstance(face, dict) for face in img_objs_individual)
# Check that the individual extraction results match the expected number of faces
for img_objs_individual, expected_faces in zip(imgs_objs_individual, expected_num_faces):
assert len(img_objs_individual) == expected_faces
# Extract faces in batch
imgs_objs_batch = DeepFace.extract_faces(
img_path=img_paths,
detector_backend=detector_backend,
align=True,
)
# Check that the batch extraction returned the expected number of face lists
assert len(imgs_objs_batch) == len(img_paths)
# Check that each face list has the expected number of faces
for i, expected_faces in enumerate(expected_num_faces):
assert len(imgs_objs_batch[i]) == expected_faces
# Check that the individual extraction results match the batch extraction results
for img_objs_individual, img_objs_batch in zip(imgs_objs_individual, imgs_objs_batch):
assert len(img_objs_batch) == len(img_objs_individual), (
"Batch and individual extraction results should have the same number of detected faces"
)
for img_obj_individual, img_obj_batch in zip(img_objs_individual, img_objs_batch):
for key in img_obj_individual["facial_area"]:
if isinstance(img_obj_individual["facial_area"][key], tuple):
for ind_val, batch_val in zip(
img_obj_individual["facial_area"][key],
img_obj_batch["facial_area"][key]
):
# Ensure the difference between individual and batch values
# is within rtol% of the individual value
assert abs(ind_val - batch_val) <= rtol * ind_val
elif isinstance(img_obj_individual["facial_area"][key], (int, float)):
# Ensure the difference between individual and batch values
# is within rtol% of the individual value
assert abs(
img_obj_individual["facial_area"][key] -
img_obj_batch["facial_area"][key]
) <= rtol * img_obj_individual["facial_area"][key]
# Ensure the confidence difference is within rtol% of the individual confidence
assert abs(
img_obj_individual["confidence"] -
img_obj_batch["confidence"]
) <= rtol * img_obj_individual["confidence"]
@pytest.mark.parametrize("detector_backend", [
"opencv",
"ssd",
"mtcnn",
"retinaface",
"yunet",
# "centerface",
# optional
# "yolov11s",
# "mediapipe",
# "dlib",
])
def test_batch_extract_faces_with_nparray(detector_backend):
img_paths = [
"dataset/img2.jpg",
"dataset/img3.jpg",
"dataset/img11.jpg",
"dataset/couple.jpg"
]
imgs = [
cv2.resize(image_utils.load_image(img_path)[0], (1920, 1080))
for img_path in img_paths
]
expected_num_faces = [1, 1, 1, 2]
# load images as numpy arrays
imgs_batch = np.stack(imgs, axis=0)
# extract faces in batch of numpy arrays
imgs_objs_batch = DeepFace.extract_faces(
img_path=imgs_batch,
detector_backend=detector_backend,
align=True,
enforce_detection=False,
)
# Check return type for batch input
assert (
isinstance(imgs_objs_batch, list) and
all(
isinstance(obj, list) and
all(isinstance(face, dict) for face in obj)
for obj in imgs_objs_batch
)
)
# Check that the batch extraction returned the expected number of face lists
assert len(imgs_objs_batch) == len(img_paths)
for img_objs_batch, img_expected_num_faces in zip(imgs_objs_batch, expected_num_faces):
assert len(img_objs_batch) == img_expected_num_faces
# extract faces in batch of paths
imgs_objs_batch_paths = DeepFace.extract_faces(
img_path=imgs,
detector_backend=detector_backend,
align=True,
enforce_detection=False,
)
# compare results
for img_objs_batch, img_objs_batch_paths in zip(imgs_objs_batch, imgs_objs_batch_paths):
assert len(img_objs_batch) == len(img_objs_batch_paths), (
"Batch and individual extraction results should have the same number of detected faces"
)
def test_batch_extract_faces_single_image():
img_path = "dataset/couple.jpg"
imgs_objs_batch = DeepFace.extract_faces(
img_path=[img_path],
align=True,
)
assert len(imgs_objs_batch) == 1 and isinstance(imgs_objs_batch[0], list)
assert all(isinstance(obj, dict) for obj in imgs_objs_batch[0])
def test_backends_for_enforced_detection_with_non_facial_inputs():
black_img = np.zeros([224, 224, 3])
for detector in detectors: