diff --git a/README.md b/README.md
index c51652a..822f298 100644
--- a/README.md
+++ b/README.md
@@ -121,7 +121,7 @@ models = [
   "ArcFace",
   "Dlib",
   "SFace",
-  "GhostFaceNet",
+  "GhostFaceNet"
 ]

 #face verification
@@ -223,6 +223,9 @@ backends = [
   'retinaface',
   'mediapipe',
   'yolov8',
+  'yolov11s',
+  'yolov11n',
+  'yolov11m',
   'yunet',
   'centerface',
 ]
diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py
index af5245f..6eb31ac 100644
--- a/deepface/DeepFace.py
+++ b/deepface/DeepFace.py
@@ -54,10 +54,10 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
     Args:
         model_name (str): model identifier
             - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
-            ArcFace, SFace, GhostFaceNet for face recognition
+            ArcFace, SFace and GhostFaceNet for face recognition
             - Age, Gender, Emotion, Race for facial attributes
-            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
-            fastmtcnn or centerface for face detectors
+            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
+            yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
             - Fasnet for spoofing
         task (str): facial_recognition, facial_attribute, face_detector, spoofing
             default is facial_recognition
@@ -68,18 +68,18 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:


 def verify(
-    img1_path: Union[str, np.ndarray, List[float]],
-    img2_path: Union[str, np.ndarray, List[float]],
-    model_name: str = "VGG-Face",
-    detector_backend: str = "opencv",
-    distance_metric: str = "cosine",
-    enforce_detection: bool = True,
-    align: bool = True,
-    expand_percentage: int = 0,
-    normalization: str = "base",
-    silent: bool = False,
-    threshold: Optional[float] = None,
-    anti_spoofing: bool = False,
+    img1_path: Union[str, np.ndarray, List[float]],
+    img2_path: Union[str, np.ndarray, List[float]],
+    model_name: str = "VGG-Face",
+    detector_backend: str = "opencv",
+    distance_metric: str = "cosine",
+    enforce_detection: bool = True,
+    align: bool = True,
+    expand_percentage: int = 0,
+    normalization: str = "base",
+    silent: bool = False,
+    threshold: Optional[float] = None,
+    anti_spoofing: bool = False,
 ) -> Dict[str, Any]:
     """
     Verify if an image pair represents the same person or different persons.
@@ -96,8 +96,8 @@ def verify(
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
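The README hunks and the `build_model` / `verify` docstring updates above all advertise the new `yolov11n` / `yolov11s` / `yolov11m` backends. A quick sketch of how a caller would pick one through the public API (the image paths here are hypothetical):

```python
from deepface import DeepFace

# any of the three new v11 variants can be passed wherever yolov8 was accepted
result = DeepFace.verify(
    img1_path="img1.jpg",
    img2_path="img2.jpg",
    detector_backend="yolov11n",
)
print(result["verified"], result["distance"])
```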
@@ -164,14 +164,14 @@ def verify(


 def analyze(
-    img_path: Union[str, np.ndarray],
-    actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
-    enforce_detection: bool = True,
-    detector_backend: str = "opencv",
-    align: bool = True,
-    expand_percentage: int = 0,
-    silent: bool = False,
-    anti_spoofing: bool = False,
+    img_path: Union[str, np.ndarray],
+    actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
+    enforce_detection: bool = True,
+    detector_backend: str = "opencv",
+    align: bool = True,
+    expand_percentage: int = 0,
+    silent: bool = False,
+    anti_spoofing: bool = False,
 ) -> List[Dict[str, Any]]:
     """
     Analyze facial attributes such as age, gender, emotion, and race in the provided image.
@@ -187,8 +187,8 @@ def analyze(
             Set to False to avoid the exception for low-resolution images (default is True).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
@@ -263,20 +263,20 @@ def analyze(


 def find(
-    img_path: Union[str, np.ndarray],
-    db_path: str,
-    model_name: str = "VGG-Face",
-    distance_metric: str = "cosine",
-    enforce_detection: bool = True,
-    detector_backend: str = "opencv",
-    align: bool = True,
-    expand_percentage: int = 0,
-    threshold: Optional[float] = None,
-    normalization: str = "base",
-    silent: bool = False,
-    refresh_database: bool = True,
-    anti_spoofing: bool = False,
-    batched: bool = False,
+    img_path: Union[str, np.ndarray],
+    db_path: str,
+    model_name: str = "VGG-Face",
+    distance_metric: str = "cosine",
+    enforce_detection: bool = True,
+    detector_backend: str = "opencv",
+    align: bool = True,
+    expand_percentage: int = 0,
+    threshold: Optional[float] = None,
+    normalization: str = "base",
+    silent: bool = False,
+    refresh_database: bool = True,
+    anti_spoofing: bool = False,
+    batched: bool = False,
 ) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
     """
     Identify individuals in a database
@@ -298,8 +298,8 @@ def find(
             Set to False to avoid the exception for low-resolution images (default is True).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         align (boolean): Perform alignment based on the eye positions (default is True).
@@ -369,15 +369,15 @@ def find(


 def represent(
-    img_path: Union[str, np.ndarray],
-    model_name: str = "VGG-Face",
-    enforce_detection: bool = True,
-    detector_backend: str = "opencv",
-    align: bool = True,
-    expand_percentage: int = 0,
-    normalization: str = "base",
-    anti_spoofing: bool = False,
-    max_faces: Optional[int] = None,
+    img_path: Union[str, np.ndarray],
+    model_name: str = "VGG-Face",
+    enforce_detection: bool = True,
+    detector_backend: str = "opencv",
+    align: bool = True,
+    expand_percentage: int = 0,
+    normalization: str = "base",
+    anti_spoofing: bool = False,
+    max_faces: Optional[int] = None,
 ) -> List[Dict[str, Any]]:
     """
     Represent facial images as multi-dimensional vector embeddings.
@@ -396,8 +396,8 @@ def represent(
             (default is True).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         align (boolean): Perform alignment based on the eye positions (default is True).
@@ -441,15 +441,15 @@ def represent(


 def stream(
-    db_path: str = "",
-    model_name: str = "VGG-Face",
-    detector_backend: str = "opencv",
-    distance_metric: str = "cosine",
-    enable_face_analysis: bool = True,
-    source: Any = 0,
-    time_threshold: int = 5,
-    frame_threshold: int = 5,
-    anti_spoofing: bool = False,
+    db_path: str = "",
+    model_name: str = "VGG-Face",
+    detector_backend: str = "opencv",
+    distance_metric: str = "cosine",
+    enable_face_analysis: bool = True,
+    source: Any = 0,
+    time_threshold: int = 5,
+    frame_threshold: int = 5,
+    anti_spoofing: bool = False,
 ) -> None:
     """
     Run real time face recognition and facial attribute analysis
@@ -462,8 +462,8 @@ def stream(
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
@@ -499,15 +499,15 @@ def stream(


 def extract_faces(
-    img_path: Union[str, np.ndarray],
-    detector_backend: str = "opencv",
-    enforce_detection: bool = True,
-    align: bool = True,
-    expand_percentage: int = 0,
-    grayscale: bool = False,
-    color_face: str = "rgb",
-    normalize_face: bool = True,
-    anti_spoofing: bool = False,
+    img_path: Union[str, np.ndarray],
+    detector_backend: str = "opencv",
+    enforce_detection: bool = True,
+    align: bool = True,
+    expand_percentage: int = 0,
+    grayscale: bool = False,
+    color_face: str = "rgb",
+    normalize_face: bool = True,
+    anti_spoofing: bool = False,
 ) -> List[Dict[str, Any]]:
     """
     Extract faces from a given image
@@ -517,8 +517,8 @@ def extract_faces(
             as a string, numpy array (BGR), or base64 encoded images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         enforce_detection (boolean): If no face is detected in an image, raise an exception.
             Set to False to avoid the exception for low-resolution images (default is True).
@@ -584,11 +584,11 @@ def cli() -> None:


 def detectFace(
-    img_path: Union[str, np.ndarray],
-    target_size: tuple = (224, 224),
-    detector_backend: str = "opencv",
-    enforce_detection: bool = True,
-    align: bool = True,
+    img_path: Union[str, np.ndarray],
+    target_size: tuple = (224, 224),
+    detector_backend: str = "opencv",
+    enforce_detection: bool = True,
+    align: bool = True,
 ) -> Union[np.ndarray, None]:
     """
     Deprecated face detection function. Use extract_faces for same functionality.
@@ -601,8 +601,8 @@ def detectFace(
             added to resize the image (default is (224, 224)).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         enforce_detection (boolean): If no face is detected in an image, raise an exception.
             Set to False to avoid the exception for low-resolution images (default is True).
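`extract_faces` and the deprecated `detectFace` now document the same expanded backend list. A minimal usage sketch (the file name is hypothetical):

```python
from deepface import DeepFace

faces = DeepFace.extract_faces(img_path="crowd.jpg", detector_backend="yolov11s")
for face in faces:
    # each entry carries the cropped face plus its region and detector confidence
    print(face["facial_area"], face["confidence"])
```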
diff --git a/deepface/api/src/modules/core/routes.py b/deepface/api/src/modules/core/routes.py
index 9cb2e74..042f4ad 100644
--- a/deepface/api/src/modules/core/routes.py
+++ b/deepface/api/src/modules/core/routes.py
@@ -72,7 +72,9 @@ def extract_image_from_request(img_key: str) -> Union[str, np.ndarray]:

 @blueprint.route("/represent", methods=["POST"])
 def represent():
-    input_args = request.get_json() or request.form.to_dict()
+    input_args = (request.is_json and request.get_json()) or (
+        request.form and request.form.to_dict()
+    )

     try:
         img = extract_image_from_request("img")
@@ -96,7 +98,9 @@ def represent():

 @blueprint.route("/verify", methods=["POST"])
 def verify():
-    input_args = request.get_json() or request.form.to_dict()
+    input_args = (request.is_json and request.get_json()) or (
+        request.form and request.form.to_dict()
+    )

     try:
         img1 = extract_image_from_request("img1")
@@ -126,7 +130,9 @@ def verify():

 @blueprint.route("/analyze", methods=["POST"])
 def analyze():
-    input_args = request.get_json() or request.form.to_dict()
+    input_args = (request.is_json and request.get_json()) or (
+        request.form and request.form.to_dict()
+    )

     try:
         img = extract_image_from_request("img")
diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py
index d6770c0..dfac1fa 100644
--- a/deepface/commons/weight_utils.py
+++ b/deepface/commons/weight_utils.py
@@ -128,8 +128,9 @@ def download_all_models_in_one_shot() -> None:
         WEIGHTS_URL as SSD_WEIGHTS,
     )
     from deepface.models.face_detection.Yolo import (
-        WEIGHT_URL as YOLOV8_WEIGHTS,
-        WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
+        WEIGHT_URLS as YOLO_WEIGHTS,
+        WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
+        YoloModel
     )
     from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
     from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS
@@ -162,8 +163,20 @@ def download_all_models_in_one_shot() -> None:
         SSD_MODEL,
         SSD_WEIGHTS,
         {
-            "filename": YOLOV8_WEIGHT_NAME,
-            "url": YOLOV8_WEIGHTS,
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V8N.value],
+            "url": YOLO_WEIGHTS[YoloModel.V8N.value],
+        },
+        {
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value],
+            "url": YOLO_WEIGHTS[YoloModel.V11N.value],
+        },
+        {
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V11S.value],
+            "url": YOLO_WEIGHTS[YoloModel.V11S.value],
+        },
+        {
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value],
+            "url": YOLO_WEIGHTS[YoloModel.V11M.value],
         },
         YUNET_WEIGHTS,
         DLIB_FD_WEIGHTS,
diff --git a/deepface/models/face_detection/CenterFace.py b/deepface/models/face_detection/CenterFace.py
index d8e08bd..b8fdf6b 100644
--- a/deepface/models/face_detection/CenterFace.py
+++ b/deepface/models/face_detection/CenterFace.py
@@ -46,7 +46,7 @@ class CenterFaceClient(Detector):
         """
         resp = []

-        threshold = float(os.getenv("CENTERFACE_THRESHOLD", "0.80"))
+        threshold = float(os.getenv("CENTERFACE_THRESHOLD", "0.35"))

         # BUG: model causes problematic results from 2nd call if it is not flushed
         # detections, landmarks = self.model.forward(
diff --git a/deepface/models/face_detection/FastMtCnn.py b/deepface/models/face_detection/FastMtCnn.py
index bc792a4..5259036 100644
--- a/deepface/models/face_detection/FastMtCnn.py
+++ b/deepface/models/face_detection/FastMtCnn.py
@@ -92,4 +92,4 @@ def xyxy_to_xywh(regions: Union[list, tuple]) -> tuple:
     x, y, x_plus_w, y_plus_h = regions[0], regions[1], regions[2], regions[3]
     w = x_plus_w - x
     h = y_plus_h - y
-    return (x, y, w, h)
+    return (int(x), int(y), int(w), int(h))
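The routes.py hunks above replace `request.get_json() or request.form.to_dict()` with a guard on `request.is_json`: on recent Flask versions (2.1+), `get_json()` raises a 415 Unsupported Media Type error for non-JSON bodies, so the old `or` fallback never ran for form or multipart posts. A minimal sketch of the same parsing rule in isolation (the `/echo` endpoint is hypothetical):

```python
from flask import Flask, request

app = Flask(__name__)

@app.route("/echo", methods=["POST"])
def echo():
    # only call get_json() when the body really is JSON; form posts
    # fall through to request.form instead of raising 415
    input_args = (request.is_json and request.get_json()) or (
        request.form and request.form.to_dict()
    )
    return dict(input_args or {})

client = app.test_client()
print(client.post("/echo", json={"img": "a.jpg"}).get_json())  # JSON body works
print(client.post("/echo", data={"img": "a.jpg"}).get_json())  # so does a form body
```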
diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py
index 77dd09b..233f088 100644
--- a/deepface/models/face_detection/Yolo.py
+++ b/deepface/models/face_detection/Yolo.py
@@ -1,29 +1,45 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import List, Any
+from enum import Enum

 # 3rd party dependencies
 import numpy as np

 # project dependencies
 from deepface.models.Detector import Detector, FacialAreaRegion
-from deepface.commons import weight_utils
 from deepface.commons.logger import Logger
+from deepface.commons import weight_utils

 logger = Logger()

+
+class YoloModel(Enum):
+    V8N = 0
+    V11N = 1
+    V11S = 2
+    V11M = 3
+
+
 # Model's weights paths
-WEIGHT_NAME = "yolov8n-face.pt"
+WEIGHT_NAMES = ["yolov8n-face.pt",
+                "yolov11n-face.pt",
+                "yolov11s-face.pt",
+                "yolov11m-face.pt"]

 # Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
-WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
+WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"]


-class YoloClient(Detector):
-    def __init__(self):
-        self.model = self.build_model()
+class YoloDetectorClient(Detector):
+    def __init__(self, model: YoloModel):
+        super().__init__()
+        self.model = self.build_model(model)

-    def build_model(self) -> Any:
+    def build_model(self, model: YoloModel) -> Any:
         """
         Build a yolo detector model
         Returns:
@@ -40,7 +56,7 @@ class YoloClient(Detector):
             ) from e

         weight_file = weight_utils.download_weights_if_necessary(
-            file_name=WEIGHT_NAME, source_url=WEIGHT_URL
+            file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value]
         )

         # Return face_detector
@@ -69,21 +85,27 @@ class YoloClient(Detector):

         # For each face, extract the bounding box, the landmarks and confidence
         for result in results:
-            if result.boxes is None or result.keypoints is None:
+            if result.boxes is None:
                 continue

             # Extract the bounding box and the confidence
             x, y, w, h = result.boxes.xywh.tolist()[0]
             confidence = result.boxes.conf.tolist()[0]

-            # right_eye_conf = result.keypoints.conf[0][0]
-            # left_eye_conf = result.keypoints.conf[0][1]
-            right_eye = result.keypoints.xy[0][0].tolist()
-            left_eye = result.keypoints.xy[0][1].tolist()
+            right_eye = None
+            left_eye = None

-            # eyes are list of float, need to cast them tuple of int
-            left_eye = tuple(int(i) for i in left_eye)
-            right_eye = tuple(int(i) for i in right_eye)
+            # yolov8-face detects eyes through keypoints,
+            # while for v11 models keypoints are always None
+            if result.keypoints is not None:
+                # right_eye_conf = result.keypoints.conf[0][0]
+                # left_eye_conf = result.keypoints.conf[0][1]
+                right_eye = result.keypoints.xy[0][0].tolist()
+                left_eye = result.keypoints.xy[0][1].tolist()
+
+                # eyes are list of float, need to cast them to tuple of int
+                left_eye = tuple(int(i) for i in left_eye)
+                right_eye = tuple(int(i) for i in right_eye)

             x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)

             facial_area = FacialAreaRegion(
@@ -98,3 +120,23 @@ class YoloClient(Detector):
             resp.append(facial_area)

         return resp
+
+
+class YoloDetectorClientV8n(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V8N)
+
+
+class YoloDetectorClientV11n(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11N)
+
+
+class YoloDetectorClientV11s(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11S)
+
+
+class YoloDetectorClientV11m(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11M)
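Each wrapper class simply pins a `YoloModel` member, so `detect_faces` behaves identically across variants except for the eye landmarks noted in the comment above. A sketch of driving the new classes directly rather than through `DeepFace.extract_faces` (assumes `ultralytics` is installed and `sample.jpg` exists; weights are downloaded on first use):

```python
import cv2
from deepface.models.face_detection.Yolo import (
    YoloDetectorClientV8n,
    YoloDetectorClientV11n,
)

img = cv2.imread("sample.jpg")  # hypothetical input image (BGR)

for detector in (YoloDetectorClientV8n(), YoloDetectorClientV11n()):
    for face in detector.detect_faces(img):
        # v8 fills left_eye/right_eye from keypoints; with the v11 weights they
        # stay None, so downstream alignment must tolerate missing eyes
        print(type(detector).__name__, face.x, face.y, face.w, face.h, face.left_eye)
```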
diff --git a/deepface/models/facial_recognition/Facenet.py b/deepface/models/facial_recognition/Facenet.py
index b75e620..15a5ba3 100644
--- a/deepface/models/facial_recognition/Facenet.py
+++ b/deepface/models/facial_recognition/Facenet.py
@@ -64,7 +64,7 @@ class FaceNet128dClient(FacialRecognition):

 class FaceNet512dClient(FacialRecognition):
     """
-    FaceNet-1512d model class
+    FaceNet-512d model class
     """

     def __init__(self):
diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py
index b68314b..2258c1e 100644
--- a/deepface/modules/demography.py
+++ b/deepface/modules/demography.py
@@ -35,8 +35,8 @@ def analyze(
             Set to False to avoid the exception for low-resolution images (default is True).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).

         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
diff --git a/deepface/modules/detection.py b/deepface/modules/detection.py
index 17bd5d9..221e1d2 100644
--- a/deepface/modules/detection.py
+++ b/deepface/modules/detection.py
@@ -38,8 +38,8 @@ def extract_faces(
             as a string, numpy array (BGR), or base64 encoded images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv)
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv)

         enforce_detection (boolean): If no face is detected in an image, raise an exception.
             Default is True. Set to False to avoid the exception for low-resolution images.
@@ -255,7 +255,7 @@ def detect_faces(
         )

     return [
-        expand_and_align_face(
+        extract_face(
             facial_area=facial_area,
             img=img,
             align=align,
@@ -267,7 +267,7 @@ def detect_faces(
     ]


-def expand_and_align_face(
+def extract_face(
     facial_area: FacialAreaRegion,
     img: np.ndarray,
     align: bool,
@@ -301,15 +301,32 @@ def expand_and_align_face(
     detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
     # align original image, then find projection of detected face area after alignment
     if align is True:  # and left_eye is not None and right_eye is not None:
-        aligned_img, angle = align_img_wrt_eyes(img=img, left_eye=left_eye, right_eye=right_eye)
+        # we were aligning the original image before, but this comes with an extra cost
+        # instead we now focus on the facial area with a margin
+        # and align it instead of the original image to decrease the cost
+        sub_img, relative_x, relative_y = extract_sub_image(img=img, facial_area=(x, y, w, h))
+
+        aligned_sub_img, angle = align_img_wrt_eyes(
+            img=sub_img, left_eye=left_eye, right_eye=right_eye
+        )

         rotated_x1, rotated_y1, rotated_x2, rotated_y2 = project_facial_area(
-            facial_area=(x, y, x + w, y + h), angle=angle, size=(img.shape[0], img.shape[1])
+            facial_area=(
+                relative_x,
+                relative_y,
+                relative_x + w,
+                relative_y + h,
+            ),
+            angle=angle,
+            size=(sub_img.shape[0], sub_img.shape[1]),
         )
-        detected_face = aligned_img[
+        detected_face = aligned_sub_img[
             int(rotated_y1) : int(rotated_y2), int(rotated_x1) : int(rotated_x2)
         ]

+        # these temporary images are no longer needed, release their memory
+        del aligned_sub_img, sub_img
+
     # restore x, y, le and re before border added
     x = x - width_border
     y = y - height_border
@@ -339,14 +356,66 @@
             mouth_left=mouth_left,
             mouth_right=mouth_right,
         ),
-        confidence=confidence,
+        confidence=confidence or 0,
     )


+def extract_sub_image(
+    img: np.ndarray, facial_area: Tuple[int, int, int, int]
+) -> Tuple[np.ndarray, int, int]:
+    """
+    Get the sub image with given facial area while expanding the facial region
+    to ensure alignment does not shift the face outside the image.
+
+    This function doubles the height and width of the face region,
+    and adds black pixels if necessary.
+
+    Args:
+        - img (np.ndarray): pre-loaded image with detected face
+        - facial_area (tuple of int): Representing the (x, y, w, h) of the facial area.
+
+    Returns:
+        - extracted_face (np.ndarray): expanded facial image
+        - relative_x (int): adjusted x-coordinates relative to the expanded region
+        - relative_y (int): adjusted y-coordinates relative to the expanded region
+    """
+    x, y, w, h = facial_area
+    relative_x = int(0.5 * w)
+    relative_y = int(0.5 * h)
+
+    # calculate expanded coordinates
+    x1, y1 = x - relative_x, y - relative_y
+    x2, y2 = x + w + relative_x, y + h + relative_y
+
+    # most of the time, the expanded region fits inside the image
+    if x1 >= 0 and y1 >= 0 and x2 <= img.shape[1] and y2 <= img.shape[0]:
+        return img[y1:y2, x1:x2], relative_x, relative_y
+
+    # but sometimes, we need to add black pixels
+    # ensure the coordinates are within bounds
+    x1, y1 = max(0, x1), max(0, y1)
+    x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)
+    cropped_region = img[y1:y2, x1:x2]
+
+    # create a black image
+    extracted_face = np.zeros(
+        (h + 2 * relative_y, w + 2 * relative_x, img.shape[2]), dtype=img.dtype
+    )
+
+    # map the cropped region
+    start_x = max(0, relative_x - x)
+    start_y = max(0, relative_y - y)
+    extracted_face[
+        start_y : start_y + cropped_region.shape[0], start_x : start_x + cropped_region.shape[1]
+    ] = cropped_region
+
+    return extracted_face, relative_x, relative_y


 def align_img_wrt_eyes(
     img: np.ndarray,
-    left_eye: Union[list, tuple],
-    right_eye: Union[list, tuple],
+    left_eye: Optional[Union[list, tuple]],
+    right_eye: Optional[Union[list, tuple]],
 ) -> Tuple[np.ndarray, float]:
     """
     Align a given image horizantally with respect to their left and right eye locations
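A worked example of the padding math in `extract_sub_image` above: a 100x100 face near the top-left corner of a frame gets a half-width/half-height margin on every side, and the part of the margin that falls outside the image is filled with black pixels rather than shifting the crop (the frame here is synthetic):

```python
import numpy as np
from deepface.modules.detection import extract_sub_image

img = np.zeros((480, 640, 3), dtype=np.uint8)  # synthetic 640x480 frame
sub_img, relative_x, relative_y = extract_sub_image(img=img, facial_area=(10, 10, 100, 100))

# the expanded region is (w + 2*w/2) x (h + 2*h/2) = 200x200; since x1 = 10 - 50 < 0,
# the out-of-bounds strip on the top/left is zero-padded
print(sub_img.shape)           # (200, 200, 3)
print(relative_x, relative_y)  # 50 50
```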
diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py
index c097c92..176d9e7 100644
--- a/deepface/modules/modeling.py
+++ b/deepface/modules/modeling.py
@@ -11,7 +11,7 @@ from deepface.models.facial_recognition import (
     SFace,
     Dlib,
     Facenet,
-    GhostFaceNet,
+    GhostFaceNet
 )
 from deepface.models.face_detection import (
     FastMtCnn,
@@ -21,7 +21,7 @@ from deepface.models.face_detection import (
     Dlib as DlibDetector,
     RetinaFace,
     Ssd,
-    Yolo,
+    Yolo as YoloFaceDetector,
     YuNet,
     CenterFace,
 )
@@ -36,10 +36,10 @@ def build_model(task: str, model_name: str) -> Any:
         task (str): facial_recognition, facial_attribute, face_detector, spoofing
         model_name (str): model identifier
             - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
-            ArcFace, SFace, GhostFaceNet for face recognition
+            ArcFace, SFace and GhostFaceNet for face recognition
             - Age, Gender, Emotion, Race for facial attributes
-            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
-            fastmtcnn or centerface for face detectors
+            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
+            yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
             - Fasnet for spoofing
     Returns:
         built model class
@@ -59,7 +59,7 @@ def build_model(task: str, model_name: str) -> Any:
             "Dlib": Dlib.DlibClient,
             "ArcFace": ArcFace.ArcFaceClient,
             "SFace": SFace.SFaceClient,
-            "GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
+            "GhostFaceNet": GhostFaceNet.GhostFaceNetClient
         },
         "spoofing": {
             "Fasnet": FasNet.Fasnet,
@@ -77,7 +77,10 @@ def build_model(task: str, model_name: str) -> Any:
             "dlib": DlibDetector.DlibClient,
             "retinaface": RetinaFace.RetinaFaceClient,
             "mediapipe": MediaPipe.MediaPipeClient,
-            "yolov8": Yolo.YoloClient,
+            "yolov8": YoloFaceDetector.YoloDetectorClientV8n,
+            "yolov11n": YoloFaceDetector.YoloDetectorClientV11n,
+            "yolov11s": YoloFaceDetector.YoloDetectorClientV11s,
+            "yolov11m": YoloFaceDetector.YoloDetectorClientV11m,
             "yunet": YuNet.YuNetClient,
             "fastmtcnn": FastMtCnn.FastMtCnnClient,
             "centerface": CenterFace.CenterFaceClient,
diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py
index df7068d..1edb430 100644
--- a/deepface/modules/recognition.py
+++ b/deepface/modules/recognition.py
@@ -54,7 +54,8 @@ def find(
             Default is True. Set to False to avoid the exception for low-resolution images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip'.

         align (boolean): Perform alignment based on the eye positions.
@@ -483,7 +484,8 @@ def find_batched(
             Default is True. Set to False to avoid the exception for low-resolution images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip'.

         align (boolean): Perform alignment based on the eye positions.
diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py
index a147640..d880645 100644
--- a/deepface/modules/representation.py
+++ b/deepface/modules/representation.py
@@ -36,7 +36,8 @@ def represent(
             Default is True. Set to False to avoid the exception for low-resolution images.

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip'.

         align (boolean): Perform alignment based on the eye positions.
@@ -115,7 +116,7 @@ def represent(
                 raise ValueError("Spoof detected in the given image.")

             img = img_obj["face"]

-            # rgb to bgr
+            # bgr to rgb
             img = img[:, :, ::-1]

             region = img_obj["facial_area"]
diff --git a/deepface/modules/streaming.py b/deepface/modules/streaming.py
index c1a0363..cc44783 100644
--- a/deepface/modules/streaming.py
+++ b/deepface/modules/streaming.py
@@ -42,11 +42,11 @@ def analysis(
         in the database will be considered in the decision-making process.

         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
-        OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
+        OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face)

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-        'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-        (default is opencv).
+        'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+        'centerface' or 'skip' (default is opencv).

         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
         'euclidean', 'euclidean_l2' (default is cosine).
@@ -192,8 +192,8 @@ def search_identity(
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity.
             Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine).
     Returns:
@@ -374,8 +374,8 @@ def grab_facial_areas(
     Args:
         img (np.ndarray): image itself
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         threshold (int): threshold for facial area, discard smaller ones
     Returns
         result (list): list of tuple with x, y, w and h coordinates
@@ -443,8 +443,8 @@ def perform_facial_recognition(
         db_path (string): Path to the folder containing image files. All detected faces
             in the database will be considered in the decision-making process.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py
index 540b63b..43c3ba9 100644
--- a/deepface/modules/verification.py
+++ b/deepface/modules/verification.py
@@ -47,8 +47,8 @@ def verify(
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv)
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv)

         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
diff --git a/tests/test_extract_faces.py b/tests/test_extract_faces.py
index ba05ab4..262d22d 100644
--- a/tests/test_extract_faces.py
+++ b/tests/test_extract_faces.py
@@ -119,25 +119,31 @@ def image_to_base64(image_path):


 def test_facial_coordinates_are_in_borders():
+    detectors = ["retinaface", "mtcnn"]
+    expected_faces = [7, 6]
+
     img_path = "dataset/selfie-many-people.jpg"
     img = cv2.imread(img_path)
     height, width, _ = img.shape

-    results = DeepFace.extract_faces(img_path=img_path)
+    for i, detector_backend in enumerate(detectors):
+        results = DeepFace.extract_faces(img_path=img_path, detector_backend=detector_backend)

-    assert len(results) > 0
+        # this is a hard example; mtcnn can detect 6 and retinaface can detect 7 faces.
+        # make sure all of those faces are detected, any change in the detection module can break this.
+        assert len(results) == expected_faces[i]

-    for result in results:
-        facial_area = result["facial_area"]
+        for result in results:
+            facial_area = result["facial_area"]

-        x = facial_area["x"]
-        y = facial_area["y"]
-        w = facial_area["w"]
-        h = facial_area["h"]
+            x = facial_area["x"]
+            y = facial_area["y"]
+            w = facial_area["w"]
+            h = facial_area["h"]

-        assert x >= 0
-        assert y >= 0
-        assert x + w < width
-        assert y + h < height
+            assert x >= 0
+            assert y >= 0
+            assert x + w < width
+            assert y + h < height

-    logger.info("✅ facial area coordinates are all in image borders")
+        logger.info(f"✅ facial area coordinates are all in image borders for {detector_backend}")
diff --git a/tests/visual-test.py b/tests/visual-test.py
index 9149bc5..9dd8986 100644
--- a/tests/visual-test.py
+++ b/tests/visual-test.py
@@ -21,7 +21,7 @@ model_names = [
     "Dlib",
     "ArcFace",
     "SFace",
-    "GhostFaceNet",
+    "GhostFaceNet"
 ]

 detector_backends = [
@@ -34,6 +34,9 @@ detector_backends = [
     "retinaface",
     "yunet",
     "yolov8",
+    "yolov11n",
+    "yolov11s",
+    "yolov11m",
     "centerface",
 ]