Merge pull request #1397 from nriviera/yolo

[FEATURE]: adding yolov11 into face detection portfolio
Sefik Ilkin Serengil 2024-12-11 11:19:22 +00:00 committed by GitHub
commit a402f09bc8
12 changed files with 198 additions and 131 deletions

View File

@@ -121,7 +121,7 @@ models = [
   "ArcFace",
   "Dlib",
   "SFace",
-  "GhostFaceNet",
+  "GhostFaceNet"
 ]

 #face verification
@@ -223,6 +223,9 @@ backends = [
   'retinaface',
   'mediapipe',
   'yolov8',
+  'yolov11s',
+  'yolov11n',
+  'yolov11m',
   'yunet',
   'centerface',
 ]
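
For reference, a minimal usage sketch of the extended backends list, assuming deepface is installed and the corresponding weights can be downloaded on first use (the image path is a placeholder):

# any of the new identifiers can be passed as detector_backend
from deepface import DeepFace

face_objs = DeepFace.extract_faces(
    img_path="img.jpg",            # placeholder path
    detector_backend="yolov11n",   # or "yolov11s", "yolov11m", "yolov8"
)
print(len(face_objs), "face(s) detected")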

View File

@@ -54,10 +54,10 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
     Args:
         model_name (str): model identifier
             - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
-            ArcFace, SFace, GhostFaceNet for face recognition
+            ArcFace, SFace and GhostFaceNet for face recognition
             - Age, Gender, Emotion, Race for facial attributes
-            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
-            fastmtcnn or centerface for face detectors
+            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
+            yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
             - Fasnet for spoofing
         task (str): facial_recognition, facial_attribute, face_detector, spoofing
             default is facial_recognition

@@ -96,8 +96,8 @@ def verify(
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).

@@ -187,8 +187,8 @@ def analyze(
             Set to False to avoid the exception for low-resolution images (default is True).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).

@@ -298,8 +298,8 @@ def find(
             Set to False to avoid the exception for low-resolution images (default is True).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         align (boolean): Perform alignment based on the eye positions (default is True).

@@ -396,8 +396,8 @@ def represent(
             (default is True).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         align (boolean): Perform alignment based on the eye positions (default is True).

@@ -462,8 +462,8 @@ def stream(
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).

@@ -517,8 +517,8 @@ def extract_faces(
             as a string, numpy array (BGR), or base64 encoded images.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         enforce_detection (boolean): If no face is detected in an image, raise an exception.
             Set to False to avoid the exception for low-resolution images (default is True).

@@ -601,8 +601,8 @@ def detectFace(
             added to resize the image (default is (224, 224)).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         enforce_detection (boolean): If no face is detected in an image, raise an exception.
             Set to False to avoid the exception for low-resolution images (default is True).
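
As a quick illustration of the updated docstrings, a hedged example of calling the public API with one of the new detector identifiers (image paths are placeholders):

from deepface import DeepFace

result = DeepFace.verify(
    img1_path="img1.jpg",          # placeholder paths
    img2_path="img2.jpg",
    detector_backend="yolov11s",   # newly documented option
)
print(result["verified"])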

View File

@@ -128,8 +128,9 @@ def download_all_models_in_one_shot() -> None:
         WEIGHTS_URL as SSD_WEIGHTS,
     )
     from deepface.models.face_detection.Yolo import (
-        WEIGHT_URL as YOLOV8_WEIGHTS,
-        WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
+        WEIGHT_URLS as YOLO_WEIGHTS,
+        WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
+        YoloModel
     )
     from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
     from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS

@@ -162,8 +163,20 @@ def download_all_models_in_one_shot() -> None:
         SSD_MODEL,
         SSD_WEIGHTS,
         {
-            "filename": YOLOV8_WEIGHT_NAME,
-            "url": YOLOV8_WEIGHTS,
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V8N.value],
+            "url": YOLO_WEIGHTS[YoloModel.V8N.value],
+        },
+        {
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value],
+            "url": YOLO_WEIGHTS[YoloModel.V11N.value],
+        },
+        {
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V11S.value],
+            "url": YOLO_WEIGHTS[YoloModel.V11S.value],
+        },
+        {
+            "filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value],
+            "url": YOLO_WEIGHTS[YoloModel.V11M.value],
         },
         YUNET_WEIGHTS,
         DLIB_FD_WEIGHTS,
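
The filename/url pairs above are produced by indexing two parallel lists with the YoloModel enum; a small sketch of that lookup, using only the names imported in this hunk:

from deepface.models.face_detection.Yolo import (
    WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
    WEIGHT_URLS as YOLO_WEIGHTS,
    YoloModel,
)

# each enum value is a position in the two parallel lists defined in Yolo.py
for variant in (YoloModel.V8N, YoloModel.V11N, YoloModel.V11S, YoloModel.V11M):
    print(YOLO_WEIGHT_NAMES[variant.value], "->", YOLO_WEIGHTS[variant.value])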

View File

@@ -1,29 +1,45 @@
 # built-in dependencies
 import os
-from typing import Any, List
+from typing import List, Any
+from enum import Enum

 # 3rd party dependencies
 import numpy as np

 # project dependencies
 from deepface.models.Detector import Detector, FacialAreaRegion
-from deepface.commons import weight_utils
 from deepface.commons.logger import Logger
+from deepface.commons import weight_utils

 logger = Logger()


+class YoloModel(Enum):
+    V8N = 0
+    V11N = 1
+    V11S = 2
+    V11M = 3
+
+
 # Model's weights paths
-WEIGHT_NAME = "yolov8n-face.pt"
+WEIGHT_NAMES = ["yolov8n-face.pt",
+                "yolov11n-face.pt",
+                "yolov11s-face.pt",
+                "yolov11m-face.pt"]

 # Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
-WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
+WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"]


-class YoloClient(Detector):
-    def __init__(self):
-        self.model = self.build_model()
+class YoloDetectorClient(Detector):
+    def __init__(self, model: YoloModel):
+        super().__init__()
+        self.model = self.build_model(model)

-    def build_model(self) -> Any:
+    def build_model(self, model: YoloModel) -> Any:
         """
         Build a yolo detector model
         Returns:

@@ -40,7 +56,7 @@ class YoloClient(Detector):
             ) from e

         weight_file = weight_utils.download_weights_if_necessary(
-            file_name=WEIGHT_NAME, source_url=WEIGHT_URL
+            file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value]
         )

         # Return face_detector

@@ -69,13 +85,19 @@ class YoloClient(Detector):
         # For each face, extract the bounding box, the landmarks and confidence
         for result in results:

-            if result.boxes is None or result.keypoints is None:
+            if result.boxes is None:
                 continue

             # Extract the bounding box and the confidence
             x, y, w, h = result.boxes.xywh.tolist()[0]
             confidence = result.boxes.conf.tolist()[0]

-            # right_eye_conf = result.keypoints.conf[0][0]
-            # left_eye_conf = result.keypoints.conf[0][1]
-            right_eye = result.keypoints.xy[0][0].tolist()
+            right_eye = None
+            left_eye = None
+
+            # yolo-facev8 is detecting eyes through keypoints,
+            # while for v11 keypoints are always None
+            if result.keypoints is not None:
+                # right_eye_conf = result.keypoints.conf[0][0]
+                # left_eye_conf = result.keypoints.conf[0][1]
+                right_eye = result.keypoints.xy[0][0].tolist()

@@ -98,3 +120,23 @@ class YoloClient(Detector):
             resp.append(facial_area)

         return resp
+
+
+class YoloDetectorClientV8n(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V8N)
+
+
+class YoloDetectorClientV11n(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11N)
+
+
+class YoloDetectorClientV11s(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11S)
+
+
+class YoloDetectorClientV11m(YoloDetectorClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11M)
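
A hedged sketch of using one of the new subclasses directly, assuming the ultralytics package is installed and that detect_faces(img) is the Detector interface method, as in the other deepface detectors:

import cv2
from deepface.models.face_detection.Yolo import YoloDetectorClientV11n

detector = YoloDetectorClientV11n()       # downloads yolov11n-face.pt on first use
img = cv2.imread("img.jpg")               # placeholder path, BGR numpy array
for face in detector.detect_faces(img):   # list of FacialAreaRegion objects
    # v11 weights return no keypoints, so left_eye and right_eye stay None here
    print(face.x, face.y, face.w, face.h, face.confidence, face.left_eye, face.right_eye)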

View File

@@ -35,8 +35,8 @@ def analyze(
             Set to False to avoid the exception for low-resolution images (default is True).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).

View File

@@ -38,8 +38,8 @@ def extract_faces(
             as a string, numpy array (BGR), or base64 encoded images.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv)
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv)
         enforce_detection (boolean): If no face is detected in an image, raise an exception.
             Default is True. Set to False to avoid the exception for low-resolution images.

View File

@@ -11,7 +11,7 @@ from deepface.models.facial_recognition import (
     SFace,
     Dlib,
     Facenet,
-    GhostFaceNet,
+    GhostFaceNet
 )
 from deepface.models.face_detection import (
     FastMtCnn,

@@ -21,7 +21,7 @@ from deepface.models.face_detection import (
     Dlib as DlibDetector,
     RetinaFace,
     Ssd,
-    Yolo,
+    Yolo as YoloFaceDetector,
     YuNet,
     CenterFace,
 )

@@ -36,10 +36,10 @@ def build_model(task: str, model_name: str) -> Any:
         task (str): facial_recognition, facial_attribute, face_detector, spoofing
         model_name (str): model identifier
             - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
-            ArcFace, SFace, GhostFaceNet for face recognition
+            ArcFace, SFace and GhostFaceNet for face recognition
             - Age, Gender, Emotion, Race for facial attributes
-            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
-            fastmtcnn or centerface for face detectors
+            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n',
+            'yolov11s', 'yolov11m', yunet, fastmtcnn or centerface for face detectors
             - Fasnet for spoofing
     Returns:
         built model class

@@ -59,7 +59,7 @@ def build_model(task: str, model_name: str) -> Any:
             "Dlib": Dlib.DlibClient,
             "ArcFace": ArcFace.ArcFaceClient,
             "SFace": SFace.SFaceClient,
-            "GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
+            "GhostFaceNet": GhostFaceNet.GhostFaceNetClient
         },
         "spoofing": {
             "Fasnet": FasNet.Fasnet,

@@ -77,7 +77,10 @@ def build_model(task: str, model_name: str) -> Any:
             "dlib": DlibDetector.DlibClient,
             "retinaface": RetinaFace.RetinaFaceClient,
             "mediapipe": MediaPipe.MediaPipeClient,
-            "yolov8": Yolo.YoloClient,
+            "yolov8": YoloFaceDetector.YoloDetectorClientV8n,
+            "yolov11n": YoloFaceDetector.YoloDetectorClientV11n,
+            "yolov11s": YoloFaceDetector.YoloDetectorClientV11s,
+            "yolov11m": YoloFaceDetector.YoloDetectorClientV11m,
             "yunet": YuNet.YuNetClient,
             "fastmtcnn": FastMtCnn.FastMtCnnClient,
             "centerface": CenterFace.CenterFaceClient,

View File

@@ -54,7 +54,8 @@ def find(
             Default is True. Set to False to avoid the exception for low-resolution images.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip'.
         align (boolean): Perform alignment based on the eye positions.

@@ -483,7 +484,8 @@ def find_batched(
             Default is True. Set to False to avoid the exception for low-resolution images.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip'.
         align (boolean): Perform alignment based on the eye positions.

View File

@@ -36,7 +36,8 @@ def represent(
             Default is True. Set to False to avoid the exception for low-resolution images.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip'.
         align (boolean): Perform alignment based on the eye positions.

View File

@@ -42,11 +42,11 @@ def analysis(
             in the database will be considered in the decision-making process.
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
-            OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
+            OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face)
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).

@@ -192,8 +192,8 @@ def search_identity(
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
     Returns:

@@ -374,8 +374,8 @@ def grab_facial_areas(
     Args:
         img (np.ndarray): image itself
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv).
         threshold (int): threshold for facial area, discard smaller ones
     Returns
         result (list): list of tuple with x, y, w and h coordinates

@@ -443,8 +443,8 @@ def perform_facial_recognition(
         db_path (string): Path to the folder containing image files. All detected faces
             in the database will be considered in the decision-making process.
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv).
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
+            'yolov11m', 'centerface' or 'skip' (default is opencv).
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,

View File

@@ -47,8 +47,8 @@ def verify(
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
-            (default is opencv)
+            'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
+            'centerface' or 'skip' (default is opencv)
         distance_metric (string): Metric for measuring similarity. Options: 'cosine',
             'euclidean', 'euclidean_l2' (default is cosine).

View File

@@ -21,7 +21,7 @@ model_names = [
     "Dlib",
     "ArcFace",
     "SFace",
-    "GhostFaceNet",
+    "GhostFaceNet"
 ]

 detector_backends = [

@@ -34,6 +34,9 @@ detector_backends = [
     "retinaface",
     "yunet",
     "yolov8",
+    "yolov11n",
+    "yolov11s",
+    "yolov11m",
     "centerface",
 ]
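
A hedged sketch of how the extended detector_backends list might be exercised in a quick smoke test, reusing the list defined above (the sample image path is hypothetical; enforce_detection is relaxed so a miss does not raise):

from deepface import DeepFace

for detector_backend in detector_backends:
    face_objs = DeepFace.extract_faces(
        img_path="dataset/img11.jpg",      # hypothetical sample image
        detector_backend=detector_backend,
        enforce_detection=False,
    )
    print(detector_backend, len(face_objs), "face(s)")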