Mirror of https://github.com/serengil/deepface.git (synced 2025-06-07 12:05:22 +00:00)

Commit 2c2dc7b1f0 (parent e73bdb8138): yolov11n and yolov11m added to model selection
@@ -223,6 +223,8 @@ backends = [
   'retinaface',
   'mediapipe',
   'yolov8',
+  'yolov11n',
+  'yolov11m',
   'yunet',
   'centerface',
 ]
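Because the entries in backends are the strings accepted by the detector_backend argument, the two new detectors become selectable from the top-level API once this list is extended. A minimal sketch, assuming the yolov11 weights download succeeds and the ultralytics package is installed (the same extra dependency the existing yolov8 backend needs); img.jpg is a placeholder path:

from deepface import DeepFace

# Detect and align faces with one of the newly added YOLO detectors.
# "yolov11n" is the lightweight variant; "yolov11m" trades speed for accuracy.
faces = DeepFace.extract_faces(
    img_path="img.jpg",           # path, numpy array (BGR) or base64 image
    detector_backend="yolov11n",  # new backend name added by this commit
)
for face in faces:
    print(face["facial_area"], face["confidence"])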
benchmarks/Evaluate-Results.ipynb (vendored, 2 lines changed)

@@ -30,7 +30,7 @@
 "source": [
 "alignment = [False, True]\n",
 "models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\"]\n",
-"detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n",
+"detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n",
 "distance_metrics = [\"euclidean\", \"euclidean_l2\", \"cosine\"]"
 ]
 },
@@ -56,7 +56,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
 - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
 ArcFace, SFace, GhostFaceNet for face recognition
 - Age, Gender, Emotion, Race for facial attributes
-- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
+- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet,
 fastmtcnn or centerface for face detectors
 - Fasnet for spoofing
 task (str): facial_recognition, facial_attribute, face_detector, spoofing

@@ -96,7 +96,7 @@ def verify(
 OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 distance_metric (string): Metric for measuring similarity. Options: 'cosine',

@@ -187,7 +187,7 @@ def analyze(
 Set to False to avoid the exception for low-resolution images (default is True).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 distance_metric (string): Metric for measuring similarity. Options: 'cosine',

@@ -298,7 +298,7 @@ def find(
 Set to False to avoid the exception for low-resolution images (default is True).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 align (boolean): Perform alignment based on the eye positions (default is True).

@@ -396,7 +396,7 @@ def represent(
 (default is True).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 align (boolean): Perform alignment based on the eye positions (default is True).

@@ -462,7 +462,7 @@ def stream(
 OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 distance_metric (string): Metric for measuring similarity. Options: 'cosine',

@@ -517,7 +517,7 @@ def extract_faces(
 as a string, numpy array (BGR), or base64 encoded images.

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 enforce_detection (boolean): If no face is detected in an image, raise an exception.

@@ -601,7 +601,7 @@ def detectFace(
 added to resize the image (default is (224, 224)).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 enforce_detection (boolean): If no face is detected in an image, raise an exception.
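All of the public entry points above pass detector_backend straight through to the detection module, so the new names work with verify, analyze, find, represent, stream and extract_faces alike. A minimal, hedged sketch with placeholder image paths:

from deepface import DeepFace

# Compare two images using the medium YOLOv11 face detector added by this commit.
# img1.jpg / img2.jpg are placeholder paths; any supported input type works.
result = DeepFace.verify(
    img1_path="img1.jpg",
    img2_path="img2.jpg",
    model_name="Facenet512",
    detector_backend="yolov11m",
)
print(result["verified"], result["distance"])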
@@ -128,8 +128,9 @@ def download_all_models_in_one_shot() -> None:
     WEIGHTS_URL as SSD_WEIGHTS,
 )
 from deepface.models.face_detection.Yolo import (
-    WEIGHT_URL as YOLOV8_WEIGHTS,
-    WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
+    WEIGHT_URLS as YOLO_WEIGHTS,
+    WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
+    YoloModel
 )
 from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
 from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS

@@ -162,8 +163,16 @@ def download_all_models_in_one_shot() -> None:
     SSD_MODEL,
     SSD_WEIGHTS,
     {
-        "filename": YOLOV8_WEIGHT_NAME,
-        "url": YOLOV8_WEIGHTS,
+        "filename": YOLO_WEIGHT_NAMES[YoloModel.V8N.value],
+        "url": YOLO_WEIGHTS[YoloModel.V8N.value],
     },
+    {
+        "filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value],
+        "url": YOLO_WEIGHTS[YoloModel.V11N.value],
+    },
+    {
+        "filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value],
+        "url": YOLO_WEIGHTS[YoloModel.V11M.value],
+    },
     YUNET_WEIGHTS,
     DLIB_FD_WEIGHTS,
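After this change the bulk-download list carries one filename/url pair per YOLO variant, indexed through the YoloModel enum. A minimal sketch of fetching a single variant with the same helper the module uses (assuming the import paths below are unchanged by this commit):

from deepface.commons import weight_utils
from deepface.models.face_detection.Yolo import WEIGHT_NAMES, WEIGHT_URLS, YoloModel

# Download only the yolov11n weights into the local weights folder if missing.
weight_file = weight_utils.download_weights_if_necessary(
    file_name=WEIGHT_NAMES[YoloModel.V11N.value],
    source_url=WEIGHT_URLS[YoloModel.V11N.value],
)
print(weight_file)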
@@ -1,6 +1,7 @@
 # built-in dependencies
 import os
 from typing import Any, List
+from enum import Enum

 # 3rd party dependencies
 import numpy as np

@@ -13,17 +14,27 @@ from deepface.commons.logger import Logger
 logger = Logger()

 # Model's weights paths
-WEIGHT_NAME = "yolov8n-face.pt"
+WEIGHT_NAMES = ["yolov8n-face.pt",
+                "yolov11n-face.pt",
+                "yolov11m-face.pt"]

 # Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
-WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
+WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt",
+               "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"]


+class YoloModel(Enum):
+    V8N = 0
+    V11N = 1
+    V11M = 2
+
+
 class YoloClient(Detector):
-    def __init__(self):
-        self.model = self.build_model()
+    def __init__(self, model: YoloModel):
+        self.model = self.build_model(model)

-    def build_model(self) -> Any:
+    def build_model(self, model: YoloModel) -> Any:
         """
         Build a yolo detector model
         Returns:

@@ -40,7 +51,7 @@ class YoloClient(Detector):
         ) from e

         weight_file = weight_utils.download_weights_if_necessary(
-            file_name=WEIGHT_NAME, source_url=WEIGHT_URL
+            file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value]
         )

         # Return face_detector

@@ -98,3 +109,18 @@ class YoloClient(Detector):
             resp.append(facial_area)

         return resp
+
+
+class YoloClientV8n(YoloClient):
+    def __init__(self):
+        super().__init__(YoloModel.V8N)
+
+
+class YoloClientV11n(YoloClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11N)
+
+
+class YoloClientV11m(YoloClient):
+    def __init__(self):
+        super().__init__(YoloModel.V11M)
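The YoloModel enum plus the three no-argument subclasses keep a single shared detection implementation while giving the registry concrete classes it can construct without parameters. A hedged sketch of using one variant directly (this assumes the ultralytics package is installed and the weight download succeeds, as for the existing yolov8 backend; the image is a placeholder array):

import numpy as np

from deepface.models.face_detection.Yolo import YoloClientV11n

# The subclass only pins the YoloModel value passed to the shared YoloClient.
detector = YoloClientV11n()

img = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
faces = detector.detect_faces(img)             # same Detector interface as the other backends
print(len(faces))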
@@ -35,7 +35,7 @@ def analyze(
 Set to False to avoid the exception for low-resolution images (default is True).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 distance_metric (string): Metric for measuring similarity. Options: 'cosine',

@@ -38,7 +38,7 @@ def extract_faces(
 as a string, numpy array (BGR), or base64 encoded images.

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv)

 enforce_detection (boolean): If no face is detected in an image, raise an exception.
@@ -38,7 +38,7 @@ def build_model(task: str, model_name: str) -> Any:
 - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
 ArcFace, SFace, GhostFaceNet for face recognition
 - Age, Gender, Emotion, Race for facial attributes
-- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet,
+- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet,
 fastmtcnn or centerface for face detectors
 - Fasnet for spoofing
 Returns:

@@ -77,7 +77,9 @@ def build_model(task: str, model_name: str) -> Any:
 "dlib": DlibDetector.DlibClient,
 "retinaface": RetinaFace.RetinaFaceClient,
 "mediapipe": MediaPipe.MediaPipeClient,
-"yolov8": Yolo.YoloClient,
+"yolov8": Yolo.YoloClientV8n,
+"yolov11n": Yolo.YoloClientV11n,
+"yolov11m": Yolo.YoloClientV11m,
 "yunet": YuNet.YuNetClient,
 "fastmtcnn": FastMtCnn.FastMtCnnClient,
 "centerface": CenterFace.CenterFaceClient,
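The registry above maps each detector_backend string to a class that can be constructed without arguments; build_model then instantiates, caches and returns it. A hedged sketch of resolving one of the new names through that function (the module path is the one used by current deepface releases, assumed unchanged by this commit):

from deepface.modules import modeling

# Resolve the string name to a detector instance via the registry shown above.
detector = modeling.build_model(task="face_detector", model_name="yolov11m")
print(type(detector).__name__)  # expected: YoloClientV11m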
@@ -54,7 +54,7 @@ def find(
 Default is True. Set to False to avoid the exception for low-resolution images.

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n','yolov11m', 'centerface' or 'skip'.

 align (boolean): Perform alignment based on the eye positions.

@@ -483,7 +483,7 @@ def find_batched(
 Default is True. Set to False to avoid the exception for low-resolution images.

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'.

 align (boolean): Perform alignment based on the eye positions.

@@ -36,7 +36,7 @@ def represent(
 Default is True. Set to False to avoid the exception for low-resolution images.

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'.

 align (boolean): Perform alignment based on the eye positions.
@@ -45,7 +45,7 @@ def analysis(
 OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).

 distance_metric (string): Metric for measuring similarity. Options: 'cosine',

@@ -192,7 +192,7 @@ def search_identity(
 model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
 OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).
 distance_metric (string): Metric for measuring similarity. Options: 'cosine',
 'euclidean', 'euclidean_l2' (default is cosine).

@@ -374,7 +374,7 @@ def grab_facial_areas(
 Args:
 img (np.ndarray): image itself
 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).
 threshold (int): threshold for facial area, discard smaller ones
 Returns

@@ -443,7 +443,7 @@ def perform_facial_recognition(
 db_path (string): Path to the folder containing image files. All detected faces
 in the database will be considered in the decision-making process.
 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv).
 distance_metric (string): Metric for measuring similarity. Options: 'cosine',
 'euclidean', 'euclidean_l2' (default is cosine).
@@ -47,7 +47,7 @@ def verify(
 OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).

 detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
-'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
+'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
 (default is opencv)

 distance_metric (string): Metric for measuring similarity. Options: 'cosine',
@@ -34,6 +34,8 @@ detector_backends = [
 "retinaface",
 "yunet",
 "yolov8",
+"yolov11n",
+"yolov11m",
 "centerface",
 ]
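With the two new names in detector_backends, the parametrised loop in the test/example script exercises them automatically. A hedged sketch of the kind of loop that list feeds (the function body and image path are illustrative, not the repository's exact test code):

from deepface import DeepFace

detector_backends = ["retinaface", "yunet", "yolov8", "yolov11n", "yolov11m", "centerface"]

# Illustrative smoke test: every backend should return at least one face
# for a clear frontal image (dataset/img1.jpg is a placeholder path).
for backend in detector_backends:
    faces = DeepFace.extract_faces(img_path="dataset/img1.jpg", detector_backend=backend)
    assert len(faces) > 0, f"{backend} found no faces"
    print(backend, "->", len(faces), "face(s)")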