mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
Merge pull request #1397 from nriviera/yolo
[FEATURE]: adding yolov11 into face detection portfolio
This commit is contained in:
commit
a402f09bc8
@ -121,7 +121,7 @@ models = [
|
||||
"ArcFace",
|
||||
"Dlib",
|
||||
"SFace",
|
||||
"GhostFaceNet",
|
||||
"GhostFaceNet"
|
||||
]
|
||||
|
||||
#face verification
|
||||
@ -223,6 +223,9 @@ backends = [
|
||||
'retinaface',
|
||||
'mediapipe',
|
||||
'yolov8',
|
||||
'yolov11s',
|
||||
'yolov11n',
|
||||
'yolov11m',
|
||||
'yunet',
|
||||
'centerface',
|
||||
]
|
||||
|
@ -54,10 +54,10 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
|
||||
Args:
|
||||
model_name (str): model identifier
|
||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||
ArcFace, SFace, GhostFaceNet for face recognition
|
||||
ArcFace, SFace and GhostFaceNet for face recognition
|
||||
- Age, Gender, Emotion, Race for facial attributes
|
||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
|
||||
fastmtcnn or centerface for face detectors
|
||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
|
||||
yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
|
||||
- Fasnet for spoofing
|
||||
task (str): facial_recognition, facial_attribute, face_detector, spoofing
|
||||
default is facial_recognition
|
||||
@ -68,18 +68,18 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
|
||||
|
||||
|
||||
def verify(
|
||||
img1_path: Union[str, np.ndarray, List[float]],
|
||||
img2_path: Union[str, np.ndarray, List[float]],
|
||||
model_name: str = "VGG-Face",
|
||||
detector_backend: str = "opencv",
|
||||
distance_metric: str = "cosine",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
normalization: str = "base",
|
||||
silent: bool = False,
|
||||
threshold: Optional[float] = None,
|
||||
anti_spoofing: bool = False,
|
||||
img1_path: Union[str, np.ndarray, List[float]],
|
||||
img2_path: Union[str, np.ndarray, List[float]],
|
||||
model_name: str = "VGG-Face",
|
||||
detector_backend: str = "opencv",
|
||||
distance_metric: str = "cosine",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
normalization: str = "base",
|
||||
silent: bool = False,
|
||||
threshold: Optional[float] = None,
|
||||
anti_spoofing: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify if an image pair represents the same person or different persons.
|
||||
@ -96,8 +96,8 @@ def verify(
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
@ -164,14 +164,14 @@ def verify(
|
||||
|
||||
|
||||
def analyze(
|
||||
img_path: Union[str, np.ndarray],
|
||||
actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
silent: bool = False,
|
||||
anti_spoofing: bool = False,
|
||||
img_path: Union[str, np.ndarray],
|
||||
actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
silent: bool = False,
|
||||
anti_spoofing: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Analyze facial attributes such as age, gender, emotion, and race in the provided image.
|
||||
@ -187,8 +187,8 @@ def analyze(
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
@ -263,20 +263,20 @@ def analyze(
|
||||
|
||||
|
||||
def find(
|
||||
img_path: Union[str, np.ndarray],
|
||||
db_path: str,
|
||||
model_name: str = "VGG-Face",
|
||||
distance_metric: str = "cosine",
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
threshold: Optional[float] = None,
|
||||
normalization: str = "base",
|
||||
silent: bool = False,
|
||||
refresh_database: bool = True,
|
||||
anti_spoofing: bool = False,
|
||||
batched: bool = False,
|
||||
img_path: Union[str, np.ndarray],
|
||||
db_path: str,
|
||||
model_name: str = "VGG-Face",
|
||||
distance_metric: str = "cosine",
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
threshold: Optional[float] = None,
|
||||
normalization: str = "base",
|
||||
silent: bool = False,
|
||||
refresh_database: bool = True,
|
||||
anti_spoofing: bool = False,
|
||||
batched: bool = False,
|
||||
) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Identify individuals in a database
|
||||
@ -298,8 +298,8 @@ def find(
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
@ -369,15 +369,15 @@ def find(
|
||||
|
||||
|
||||
def represent(
|
||||
img_path: Union[str, np.ndarray],
|
||||
model_name: str = "VGG-Face",
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
normalization: str = "base",
|
||||
anti_spoofing: bool = False,
|
||||
max_faces: Optional[int] = None,
|
||||
img_path: Union[str, np.ndarray],
|
||||
model_name: str = "VGG-Face",
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
normalization: str = "base",
|
||||
anti_spoofing: bool = False,
|
||||
max_faces: Optional[int] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Represent facial images as multi-dimensional vector embeddings.
|
||||
@ -396,8 +396,8 @@ def represent(
|
||||
(default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
@ -441,15 +441,15 @@ def represent(
|
||||
|
||||
|
||||
def stream(
|
||||
db_path: str = "",
|
||||
model_name: str = "VGG-Face",
|
||||
detector_backend: str = "opencv",
|
||||
distance_metric: str = "cosine",
|
||||
enable_face_analysis: bool = True,
|
||||
source: Any = 0,
|
||||
time_threshold: int = 5,
|
||||
frame_threshold: int = 5,
|
||||
anti_spoofing: bool = False,
|
||||
db_path: str = "",
|
||||
model_name: str = "VGG-Face",
|
||||
detector_backend: str = "opencv",
|
||||
distance_metric: str = "cosine",
|
||||
enable_face_analysis: bool = True,
|
||||
source: Any = 0,
|
||||
time_threshold: int = 5,
|
||||
frame_threshold: int = 5,
|
||||
anti_spoofing: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Run real time face recognition and facial attribute analysis
|
||||
@ -462,8 +462,8 @@ def stream(
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
@ -499,15 +499,15 @@ def stream(
|
||||
|
||||
|
||||
def extract_faces(
|
||||
img_path: Union[str, np.ndarray],
|
||||
detector_backend: str = "opencv",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
grayscale: bool = False,
|
||||
color_face: str = "rgb",
|
||||
normalize_face: bool = True,
|
||||
anti_spoofing: bool = False,
|
||||
img_path: Union[str, np.ndarray],
|
||||
detector_backend: str = "opencv",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
expand_percentage: int = 0,
|
||||
grayscale: bool = False,
|
||||
color_face: str = "rgb",
|
||||
normalize_face: bool = True,
|
||||
anti_spoofing: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract faces from a given image
|
||||
@ -517,8 +517,8 @@ def extract_faces(
|
||||
as a string, numpy array (BGR), or base64 encoded images.
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
@ -584,11 +584,11 @@ def cli() -> None:
|
||||
|
||||
|
||||
def detectFace(
|
||||
img_path: Union[str, np.ndarray],
|
||||
target_size: tuple = (224, 224),
|
||||
detector_backend: str = "opencv",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
img_path: Union[str, np.ndarray],
|
||||
target_size: tuple = (224, 224),
|
||||
detector_backend: str = "opencv",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
) -> Union[np.ndarray, None]:
|
||||
"""
|
||||
Deprecated face detection function. Use extract_faces for same functionality.
|
||||
@ -601,8 +601,8 @@ def detectFace(
|
||||
added to resize the image (default is (224, 224)).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
@ -128,8 +128,9 @@ def download_all_models_in_one_shot() -> None:
|
||||
WEIGHTS_URL as SSD_WEIGHTS,
|
||||
)
|
||||
from deepface.models.face_detection.Yolo import (
|
||||
WEIGHT_URL as YOLOV8_WEIGHTS,
|
||||
WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
|
||||
WEIGHT_URLS as YOLO_WEIGHTS,
|
||||
WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
|
||||
YoloModel
|
||||
)
|
||||
from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
|
||||
from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS
|
||||
@ -162,8 +163,20 @@ def download_all_models_in_one_shot() -> None:
|
||||
SSD_MODEL,
|
||||
SSD_WEIGHTS,
|
||||
{
|
||||
"filename": YOLOV8_WEIGHT_NAME,
|
||||
"url": YOLOV8_WEIGHTS,
|
||||
"filename": YOLO_WEIGHT_NAMES[YoloModel.V8N.value],
|
||||
"url": YOLO_WEIGHTS[YoloModel.V8N.value],
|
||||
},
|
||||
{
|
||||
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value],
|
||||
"url": YOLO_WEIGHTS[YoloModel.V11N.value],
|
||||
},
|
||||
{
|
||||
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11S.value],
|
||||
"url": YOLO_WEIGHTS[YoloModel.V11S.value],
|
||||
},
|
||||
{
|
||||
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value],
|
||||
"url": YOLO_WEIGHTS[YoloModel.V11M.value],
|
||||
},
|
||||
YUNET_WEIGHTS,
|
||||
DLIB_FD_WEIGHTS,
|
||||
|
@ -1,29 +1,45 @@
|
||||
# built-in dependencies
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import List, Any
|
||||
from enum import Enum
|
||||
|
||||
# 3rd party dependencies
|
||||
import numpy as np
|
||||
|
||||
# project dependencies
|
||||
from deepface.models.Detector import Detector, FacialAreaRegion
|
||||
from deepface.commons import weight_utils
|
||||
from deepface.commons.logger import Logger
|
||||
from deepface.commons import weight_utils
|
||||
|
||||
logger = Logger()
|
||||
|
||||
|
||||
class YoloModel(Enum):
|
||||
V8N = 0
|
||||
V11N = 1
|
||||
V11S = 2
|
||||
V11M = 3
|
||||
|
||||
|
||||
# Model's weights paths
|
||||
WEIGHT_NAME = "yolov8n-face.pt"
|
||||
WEIGHT_NAMES = ["yolov8n-face.pt",
|
||||
"yolov11n-face.pt",
|
||||
"yolov11s-face.pt",
|
||||
"yolov11m-face.pt"]
|
||||
|
||||
# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
|
||||
WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
|
||||
WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
|
||||
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt",
|
||||
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt",
|
||||
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"]
|
||||
|
||||
|
||||
class YoloClient(Detector):
|
||||
def __init__(self):
|
||||
self.model = self.build_model()
|
||||
class YoloDetectorClient(Detector):
|
||||
def __init__(self, model: YoloModel):
|
||||
super().__init__()
|
||||
self.model = self.build_model(model)
|
||||
|
||||
def build_model(self) -> Any:
|
||||
def build_model(self, model: YoloModel) -> Any:
|
||||
"""
|
||||
Build a yolo detector model
|
||||
Returns:
|
||||
@ -40,7 +56,7 @@ class YoloClient(Detector):
|
||||
) from e
|
||||
|
||||
weight_file = weight_utils.download_weights_if_necessary(
|
||||
file_name=WEIGHT_NAME, source_url=WEIGHT_URL
|
||||
file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value]
|
||||
)
|
||||
|
||||
# Return face_detector
|
||||
@ -69,21 +85,27 @@ class YoloClient(Detector):
|
||||
# For each face, extract the bounding box, the landmarks and confidence
|
||||
for result in results:
|
||||
|
||||
if result.boxes is None or result.keypoints is None:
|
||||
if result.boxes is None:
|
||||
continue
|
||||
|
||||
# Extract the bounding box and the confidence
|
||||
x, y, w, h = result.boxes.xywh.tolist()[0]
|
||||
confidence = result.boxes.conf.tolist()[0]
|
||||
|
||||
# right_eye_conf = result.keypoints.conf[0][0]
|
||||
# left_eye_conf = result.keypoints.conf[0][1]
|
||||
right_eye = result.keypoints.xy[0][0].tolist()
|
||||
left_eye = result.keypoints.xy[0][1].tolist()
|
||||
right_eye = None
|
||||
left_eye = None
|
||||
|
||||
# eyes are list of float, need to cast them tuple of int
|
||||
left_eye = tuple(int(i) for i in left_eye)
|
||||
right_eye = tuple(int(i) for i in right_eye)
|
||||
# yolo-facev8 is detecting eyes through keypoints,
|
||||
# while for v11 keypoints are always None
|
||||
if result.keypoints is not None:
|
||||
# right_eye_conf = result.keypoints.conf[0][0]
|
||||
# left_eye_conf = result.keypoints.conf[0][1]
|
||||
right_eye = result.keypoints.xy[0][0].tolist()
|
||||
left_eye = result.keypoints.xy[0][1].tolist()
|
||||
|
||||
# eyes are list of float, need to cast them tuple of int
|
||||
left_eye = tuple(int(i) for i in left_eye)
|
||||
right_eye = tuple(int(i) for i in right_eye)
|
||||
|
||||
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
|
||||
facial_area = FacialAreaRegion(
|
||||
@ -98,3 +120,23 @@ class YoloClient(Detector):
|
||||
resp.append(facial_area)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class YoloDetectorClientV8n(YoloDetectorClient):
|
||||
def __init__(self):
|
||||
super().__init__(YoloModel.V8N)
|
||||
|
||||
|
||||
class YoloDetectorClientV11n(YoloDetectorClient):
|
||||
def __init__(self):
|
||||
super().__init__(YoloModel.V11N)
|
||||
|
||||
|
||||
class YoloDetectorClientV11s(YoloDetectorClient):
|
||||
def __init__(self):
|
||||
super().__init__(YoloModel.V11S)
|
||||
|
||||
|
||||
class YoloDetectorClientV11m(YoloDetectorClient):
|
||||
def __init__(self):
|
||||
super().__init__(YoloModel.V11M)
|
||||
|
@ -35,8 +35,8 @@ def analyze(
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
|
@ -38,8 +38,8 @@ def extract_faces(
|
||||
as a string, numpy array (BGR), or base64 encoded images.
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv)
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv)
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
|
@ -11,7 +11,7 @@ from deepface.models.facial_recognition import (
|
||||
SFace,
|
||||
Dlib,
|
||||
Facenet,
|
||||
GhostFaceNet,
|
||||
GhostFaceNet
|
||||
)
|
||||
from deepface.models.face_detection import (
|
||||
FastMtCnn,
|
||||
@ -21,7 +21,7 @@ from deepface.models.face_detection import (
|
||||
Dlib as DlibDetector,
|
||||
RetinaFace,
|
||||
Ssd,
|
||||
Yolo,
|
||||
Yolo as YoloFaceDetector,
|
||||
YuNet,
|
||||
CenterFace,
|
||||
)
|
||||
@ -36,10 +36,10 @@ def build_model(task: str, model_name: str) -> Any:
|
||||
task (str): facial_recognition, facial_attribute, face_detector, spoofing
|
||||
model_name (str): model identifier
|
||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||
ArcFace, SFace, GhostFaceNet for face recognition
|
||||
ArcFace, SFace and GhostFaceNet for face recognition
|
||||
- Age, Gender, Emotion, Race for facial attributes
|
||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
|
||||
fastmtcnn or centerface for face detectors
|
||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n',
|
||||
'yolov11s', 'yolov11m', yunet, fastmtcnn or centerface for face detectors
|
||||
- Fasnet for spoofing
|
||||
Returns:
|
||||
built model class
|
||||
@ -59,7 +59,7 @@ def build_model(task: str, model_name: str) -> Any:
|
||||
"Dlib": Dlib.DlibClient,
|
||||
"ArcFace": ArcFace.ArcFaceClient,
|
||||
"SFace": SFace.SFaceClient,
|
||||
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
|
||||
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient
|
||||
},
|
||||
"spoofing": {
|
||||
"Fasnet": FasNet.Fasnet,
|
||||
@ -77,7 +77,10 @@ def build_model(task: str, model_name: str) -> Any:
|
||||
"dlib": DlibDetector.DlibClient,
|
||||
"retinaface": RetinaFace.RetinaFaceClient,
|
||||
"mediapipe": MediaPipe.MediaPipeClient,
|
||||
"yolov8": Yolo.YoloClient,
|
||||
"yolov8": YoloFaceDetector.YoloDetectorClientV8n,
|
||||
"yolov11n": YoloFaceDetector.YoloDetectorClientV11n,
|
||||
"yolov11s": YoloFaceDetector.YoloDetectorClientV11s,
|
||||
"yolov11m": YoloFaceDetector.YoloDetectorClientV11m,
|
||||
"yunet": YuNet.YuNetClient,
|
||||
"fastmtcnn": FastMtCnn.FastMtCnnClient,
|
||||
"centerface": CenterFace.CenterFaceClient,
|
||||
|
@ -54,7 +54,8 @@ def find(
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n', 'yolov11s',
|
||||
'yolov11m', 'centerface' or 'skip'.
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions.
|
||||
|
||||
@ -483,7 +484,8 @@ def find_batched(
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
|
||||
'yolov11m', 'centerface' or 'skip'.
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions.
|
||||
|
||||
|
@ -36,7 +36,8 @@ def represent(
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
|
||||
'yolov11m', 'centerface' or 'skip'.
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions.
|
||||
|
||||
|
@ -42,11 +42,11 @@ def analysis(
|
||||
in the database will be considered in the decision-making process.
|
||||
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face)
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
@ -192,8 +192,8 @@ def search_identity(
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
Returns:
|
||||
@ -374,8 +374,8 @@ def grab_facial_areas(
|
||||
Args:
|
||||
img (np.ndarray): image itself
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
threshold (int): threshold for facial area, discard smaller ones
|
||||
Returns
|
||||
result (list): list of tuple with x, y, w and h coordinates
|
||||
@ -443,8 +443,8 @@ def perform_facial_recognition(
|
||||
db_path (string): Path to the folder containing image files. All detected faces
|
||||
in the database will be considered in the decision-making process.
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv).
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
|
||||
'yolov11m', 'centerface' or 'skip' (default is opencv).
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
|
@ -47,8 +47,8 @@ def verify(
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
|
||||
(default is opencv)
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv)
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
|
@ -21,7 +21,7 @@ model_names = [
|
||||
"Dlib",
|
||||
"ArcFace",
|
||||
"SFace",
|
||||
"GhostFaceNet",
|
||||
"GhostFaceNet"
|
||||
]
|
||||
|
||||
detector_backends = [
|
||||
@ -34,6 +34,9 @@ detector_backends = [
|
||||
"retinaface",
|
||||
"yunet",
|
||||
"yolov8",
|
||||
"yolov11n",
|
||||
"yolov11s",
|
||||
"yolov11m",
|
||||
"centerface",
|
||||
]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user