mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
yolov11n and yolov11m added to model selection
This commit is contained in:
parent
2c2dc7b1f0
commit
38261e07e5
4
benchmarks/Evaluate-Results.ipynb
vendored
4
benchmarks/Evaluate-Results.ipynb
vendored
@ -29,8 +29,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"alignment = [False, True]\n",
|
"alignment = [False, True]\n",
|
||||||
"models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\"]\n",
|
"models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\", \"yolov8\", \"yolov11n\", \"yolov11s\", \"yolov11m\"]\n",
|
||||||
"detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n",
|
"detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11s\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n",
|
||||||
"distance_metrics = [\"euclidean\", \"euclidean_l2\", \"cosine\"]"
|
"distance_metrics = [\"euclidean\", \"euclidean_l2\", \"cosine\"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -56,7 +56,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
|
|||||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||||
ArcFace, SFace, GhostFaceNet for face recognition
|
ArcFace, SFace, GhostFaceNet for face recognition
|
||||||
- Age, Gender, Emotion, Race for facial attributes
|
- Age, Gender, Emotion, Race for facial attributes
|
||||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet,
|
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s','yolov11m', yunet,
|
||||||
fastmtcnn or centerface for face detectors
|
fastmtcnn or centerface for face detectors
|
||||||
- Fasnet for spoofing
|
- Fasnet for spoofing
|
||||||
task (str): facial_recognition, facial_attribute, face_detector, spoofing
|
task (str): facial_recognition, facial_attribute, face_detector, spoofing
|
||||||
@ -96,7 +96,7 @@ def verify(
|
|||||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
@ -187,7 +187,7 @@ def analyze(
|
|||||||
Set to False to avoid the exception for low-resolution images (default is True).
|
Set to False to avoid the exception for low-resolution images (default is True).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
@ -298,7 +298,7 @@ def find(
|
|||||||
Set to False to avoid the exception for low-resolution images (default is True).
|
Set to False to avoid the exception for low-resolution images (default is True).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||||
@ -396,7 +396,7 @@ def represent(
|
|||||||
(default is True).
|
(default is True).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||||
@ -462,7 +462,7 @@ def stream(
|
|||||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
@ -517,7 +517,7 @@ def extract_faces(
|
|||||||
as a string, numpy array (BGR), or base64 encoded images.
|
as a string, numpy array (BGR), or base64 encoded images.
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||||
@ -601,7 +601,7 @@ def detectFace(
|
|||||||
added to resize the image (default is (224, 224)).
|
added to resize the image (default is (224, 224)).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||||
|
@ -127,7 +127,7 @@ def download_all_models_in_one_shot() -> None:
|
|||||||
MODEL_URL as SSD_MODEL,
|
MODEL_URL as SSD_MODEL,
|
||||||
WEIGHTS_URL as SSD_WEIGHTS,
|
WEIGHTS_URL as SSD_WEIGHTS,
|
||||||
)
|
)
|
||||||
from deepface.models.face_detection.Yolo import (
|
from deepface.models.YoloModel import (
|
||||||
WEIGHT_URLS as YOLO_WEIGHTS,
|
WEIGHT_URLS as YOLO_WEIGHTS,
|
||||||
WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
|
WEIGHT_NAMES as YOLO_WEIGHT_NAMES,
|
||||||
YoloModel
|
YoloModel
|
||||||
@ -170,6 +170,10 @@ def download_all_models_in_one_shot() -> None:
|
|||||||
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value],
|
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value],
|
||||||
"url": YOLO_WEIGHTS[YoloModel.V11N.value],
|
"url": YOLO_WEIGHTS[YoloModel.V11N.value],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11S.value],
|
||||||
|
"url": YOLO_WEIGHTS[YoloModel.V11S.value],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value],
|
"filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value],
|
||||||
"url": YOLO_WEIGHTS[YoloModel.V11M.value],
|
"url": YOLO_WEIGHTS[YoloModel.V11M.value],
|
||||||
|
37
deepface/models/YoloClientBase.py
Normal file
37
deepface/models/YoloClientBase.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# built-in dependencies
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# project dependencies
|
||||||
|
from deepface.models.YoloModel import YoloModel, WEIGHT_URLS, WEIGHT_NAMES
|
||||||
|
from deepface.commons import weight_utils
|
||||||
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
|
logger = Logger()
|
||||||
|
|
||||||
|
|
||||||
|
class YoloClientBase:
|
||||||
|
def __init__(self, model: YoloModel):
|
||||||
|
self.model = self.build_model(model)
|
||||||
|
|
||||||
|
def build_model(self, model: YoloModel) -> Any:
|
||||||
|
"""
|
||||||
|
Build a yolo detector model
|
||||||
|
Returns:
|
||||||
|
model (Any)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import the optional Ultralytics YOLO model
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Yolo is an optional detector, ensure the library is installed. "
|
||||||
|
"Please install using 'pip install ultralytics'"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
weight_file = weight_utils.download_weights_if_necessary(
|
||||||
|
file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return face_detector
|
||||||
|
return YOLO(weight_file)
|
21
deepface/models/YoloModel.py
Normal file
21
deepface/models/YoloModel.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class YoloModel(Enum):
|
||||||
|
V8N = 0
|
||||||
|
V11N = 1
|
||||||
|
V11S = 2
|
||||||
|
V11M = 3
|
||||||
|
|
||||||
|
|
||||||
|
# Model's weights paths
|
||||||
|
WEIGHT_NAMES = ["yolov8n-face.pt",
|
||||||
|
"yolov11n-face.pt",
|
||||||
|
"yolov11s-face.pt",
|
||||||
|
"yolov11m-face.pt"]
|
||||||
|
|
||||||
|
# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
|
||||||
|
WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
|
||||||
|
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt",
|
||||||
|
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt",
|
||||||
|
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"]
|
@ -1,61 +1,22 @@
|
|||||||
# built-in dependencies
|
# built-in dependencies
|
||||||
import os
|
import os
|
||||||
from typing import Any, List
|
from typing import List
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# project dependencies
|
# project dependencies
|
||||||
|
from deepface.models.YoloClientBase import YoloClientBase
|
||||||
|
from deepface.models.YoloModel import YoloModel
|
||||||
from deepface.models.Detector import Detector, FacialAreaRegion
|
from deepface.models.Detector import Detector, FacialAreaRegion
|
||||||
from deepface.commons import weight_utils
|
|
||||||
from deepface.commons.logger import Logger
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
|
|
||||||
# Model's weights paths
|
|
||||||
WEIGHT_NAMES = ["yolov8n-face.pt",
|
|
||||||
"yolov11n-face.pt",
|
|
||||||
"yolov11m-face.pt"]
|
|
||||||
|
|
||||||
# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
|
class YoloDetectorClient(YoloClientBase, Detector):
|
||||||
WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
|
|
||||||
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt",
|
|
||||||
"https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"]
|
|
||||||
|
|
||||||
|
|
||||||
class YoloModel(Enum):
|
|
||||||
V8N = 0
|
|
||||||
V11N = 1
|
|
||||||
V11M = 2
|
|
||||||
|
|
||||||
|
|
||||||
class YoloClient(Detector):
|
|
||||||
def __init__(self, model: YoloModel):
|
def __init__(self, model: YoloModel):
|
||||||
self.model = self.build_model(model)
|
super().__init__(model)
|
||||||
|
|
||||||
def build_model(self, model: YoloModel) -> Any:
|
|
||||||
"""
|
|
||||||
Build a yolo detector model
|
|
||||||
Returns:
|
|
||||||
model (Any)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Import the optional Ultralytics YOLO model
|
|
||||||
try:
|
|
||||||
from ultralytics import YOLO
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Yolo is an optional detector, ensure the library is installed. "
|
|
||||||
"Please install using 'pip install ultralytics'"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
weight_file = weight_utils.download_weights_if_necessary(
|
|
||||||
file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return face_detector
|
|
||||||
return YOLO(weight_file)
|
|
||||||
|
|
||||||
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
|
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
|
||||||
"""
|
"""
|
||||||
@ -80,21 +41,24 @@ class YoloClient(Detector):
|
|||||||
# For each face, extract the bounding box, the landmarks and confidence
|
# For each face, extract the bounding box, the landmarks and confidence
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
||||||
if result.boxes is None or result.keypoints is None:
|
if result.boxes is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract the bounding box and the confidence
|
# Extract the bounding box and the confidence
|
||||||
x, y, w, h = result.boxes.xywh.tolist()[0]
|
x, y, w, h = result.boxes.xywh.tolist()[0]
|
||||||
confidence = result.boxes.conf.tolist()[0]
|
confidence = result.boxes.conf.tolist()[0]
|
||||||
|
|
||||||
# right_eye_conf = result.keypoints.conf[0][0]
|
right_eye = None
|
||||||
# left_eye_conf = result.keypoints.conf[0][1]
|
left_eye = None
|
||||||
right_eye = result.keypoints.xy[0][0].tolist()
|
if result.keypoints is not None:
|
||||||
left_eye = result.keypoints.xy[0][1].tolist()
|
# right_eye_conf = result.keypoints.conf[0][0]
|
||||||
|
# left_eye_conf = result.keypoints.conf[0][1]
|
||||||
|
right_eye = result.keypoints.xy[0][0].tolist()
|
||||||
|
left_eye = result.keypoints.xy[0][1].tolist()
|
||||||
|
|
||||||
# eyes are list of float, need to cast them tuple of int
|
# eyes are list of float, need to cast them tuple of int
|
||||||
left_eye = tuple(int(i) for i in left_eye)
|
left_eye = tuple(int(i) for i in left_eye)
|
||||||
right_eye = tuple(int(i) for i in right_eye)
|
right_eye = tuple(int(i) for i in right_eye)
|
||||||
|
|
||||||
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
|
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
|
||||||
facial_area = FacialAreaRegion(
|
facial_area = FacialAreaRegion(
|
||||||
@ -111,16 +75,21 @@ class YoloClient(Detector):
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
class YoloClientV8n(YoloClient):
|
class YoloDetectorClientV8n(YoloDetectorClient):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(YoloModel.V8N)
|
super().__init__(YoloModel.V8N)
|
||||||
|
|
||||||
|
|
||||||
class YoloClientV11n(YoloClient):
|
class YoloDetectorClientV11n(YoloDetectorClient):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(YoloModel.V11N)
|
super().__init__(YoloModel.V11N)
|
||||||
|
|
||||||
|
|
||||||
class YoloClientV11m(YoloClient):
|
class YoloDetectorClientV11s(YoloDetectorClient):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(YoloModel.V11S)
|
||||||
|
|
||||||
|
|
||||||
|
class YoloDetectorClientV11m(YoloDetectorClient):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(YoloModel.V11M)
|
super().__init__(YoloModel.V11M)
|
||||||
|
44
deepface/models/facial_recognition/Yolo.py
Normal file
44
deepface/models/facial_recognition/Yolo.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# built-in dependencies
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
# 3rd party dependencies
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# project dependencies
|
||||||
|
from deepface.models.YoloClientBase import YoloClientBase
|
||||||
|
from deepface.models.YoloModel import YoloModel
|
||||||
|
from deepface.models.FacialRecognition import FacialRecognition
|
||||||
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
|
logger = Logger()
|
||||||
|
|
||||||
|
|
||||||
|
class YoloFacialRecognitionClient(YoloClientBase, FacialRecognition):
|
||||||
|
def __init__(self, model: YoloModel):
|
||||||
|
super().__init__(model)
|
||||||
|
self.model_name = "Yolo"
|
||||||
|
self.input_shape = None
|
||||||
|
self.output_shape = 512
|
||||||
|
|
||||||
|
def forward(self, img: np.ndarray) -> List[float]:
|
||||||
|
return self.model.embed(img)[0].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(YoloModel.V8N)
|
||||||
|
|
||||||
|
|
||||||
|
class YoloFacialRecognitionClientV11n(YoloFacialRecognitionClient):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(YoloModel.V11N)
|
||||||
|
|
||||||
|
|
||||||
|
class YoloFacialRecognitionClientV11s(YoloFacialRecognitionClient):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(YoloModel.V11S)
|
||||||
|
|
||||||
|
|
||||||
|
class YoloFacialRecognitionClientV11m(YoloFacialRecognitionClient):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(YoloModel.V11M)
|
@ -35,7 +35,7 @@ def analyze(
|
|||||||
Set to False to avoid the exception for low-resolution images (default is True).
|
Set to False to avoid the exception for low-resolution images (default is True).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
|
@ -38,7 +38,7 @@ def extract_faces(
|
|||||||
as a string, numpy array (BGR), or base64 encoded images.
|
as a string, numpy array (BGR), or base64 encoded images.
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv)
|
(default is opencv)
|
||||||
|
|
||||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||||
|
@ -12,6 +12,7 @@ from deepface.models.facial_recognition import (
|
|||||||
Dlib,
|
Dlib,
|
||||||
Facenet,
|
Facenet,
|
||||||
GhostFaceNet,
|
GhostFaceNet,
|
||||||
|
Yolo as YoloFacialRecognition,
|
||||||
)
|
)
|
||||||
from deepface.models.face_detection import (
|
from deepface.models.face_detection import (
|
||||||
FastMtCnn,
|
FastMtCnn,
|
||||||
@ -21,7 +22,7 @@ from deepface.models.face_detection import (
|
|||||||
Dlib as DlibDetector,
|
Dlib as DlibDetector,
|
||||||
RetinaFace,
|
RetinaFace,
|
||||||
Ssd,
|
Ssd,
|
||||||
Yolo,
|
Yolo as YoloFaceDetector,
|
||||||
YuNet,
|
YuNet,
|
||||||
CenterFace,
|
CenterFace,
|
||||||
)
|
)
|
||||||
@ -38,7 +39,7 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||||
ArcFace, SFace, GhostFaceNet for face recognition
|
ArcFace, SFace, GhostFaceNet for face recognition
|
||||||
- Age, Gender, Emotion, Race for facial attributes
|
- Age, Gender, Emotion, Race for facial attributes
|
||||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet,
|
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s', 'yolov11m', yunet,
|
||||||
fastmtcnn or centerface for face detectors
|
fastmtcnn or centerface for face detectors
|
||||||
- Fasnet for spoofing
|
- Fasnet for spoofing
|
||||||
Returns:
|
Returns:
|
||||||
@ -60,6 +61,10 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
"ArcFace": ArcFace.ArcFaceClient,
|
"ArcFace": ArcFace.ArcFaceClient,
|
||||||
"SFace": SFace.SFaceClient,
|
"SFace": SFace.SFaceClient,
|
||||||
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
|
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
|
||||||
|
"yolov8": YoloFacialRecognition.YoloFacialRecognitionClientV8n,
|
||||||
|
"yolov11n": YoloFacialRecognition.YoloFacialRecognitionClientV11n,
|
||||||
|
"yolov11s": YoloFacialRecognition.YoloFacialRecognitionClientV11s,
|
||||||
|
"yolov11m": YoloFacialRecognition.YoloFacialRecognitionClientV11m
|
||||||
},
|
},
|
||||||
"spoofing": {
|
"spoofing": {
|
||||||
"Fasnet": FasNet.Fasnet,
|
"Fasnet": FasNet.Fasnet,
|
||||||
@ -77,9 +82,10 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
"dlib": DlibDetector.DlibClient,
|
"dlib": DlibDetector.DlibClient,
|
||||||
"retinaface": RetinaFace.RetinaFaceClient,
|
"retinaface": RetinaFace.RetinaFaceClient,
|
||||||
"mediapipe": MediaPipe.MediaPipeClient,
|
"mediapipe": MediaPipe.MediaPipeClient,
|
||||||
"yolov8": Yolo.YoloClientV8n,
|
"yolov8": YoloFaceDetector.YoloDetectorClientV8n,
|
||||||
"yolov11n": Yolo.YoloClientV11n,
|
"yolov11n": YoloFaceDetector.YoloDetectorClientV11n,
|
||||||
"yolov11m": Yolo.YoloClientV11m,
|
"yolov11s": YoloFaceDetector.YoloDetectorClientV11s,
|
||||||
|
"yolov11m": YoloFaceDetector.YoloDetectorClientV11m,
|
||||||
"yunet": YuNet.YuNetClient,
|
"yunet": YuNet.YuNetClient,
|
||||||
"fastmtcnn": FastMtCnn.FastMtCnnClient,
|
"fastmtcnn": FastMtCnn.FastMtCnnClient,
|
||||||
"centerface": CenterFace.CenterFaceClient,
|
"centerface": CenterFace.CenterFaceClient,
|
||||||
|
@ -54,7 +54,7 @@ def find(
|
|||||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n','yolov11m', 'centerface' or 'skip'.
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'.
|
||||||
|
|
||||||
align (boolean): Perform alignment based on the eye positions.
|
align (boolean): Perform alignment based on the eye positions.
|
||||||
|
|
||||||
@ -483,7 +483,7 @@ def find_batched(
|
|||||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'.
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'.
|
||||||
|
|
||||||
align (boolean): Perform alignment based on the eye positions.
|
align (boolean): Perform alignment based on the eye positions.
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ def represent(
|
|||||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'.
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'.
|
||||||
|
|
||||||
align (boolean): Perform alignment based on the eye positions.
|
align (boolean): Perform alignment based on the eye positions.
|
||||||
|
|
||||||
@ -122,11 +122,12 @@ def represent(
|
|||||||
confidence = img_obj["confidence"]
|
confidence = img_obj["confidence"]
|
||||||
|
|
||||||
# resize to expected shape of ml model
|
# resize to expected shape of ml model
|
||||||
img = preprocessing.resize_image(
|
if target_size is not None:
|
||||||
img=img,
|
img = preprocessing.resize_image(
|
||||||
# thanks to DeepId (!)
|
img=img,
|
||||||
target_size=(target_size[1], target_size[0]),
|
# thanks to DeepId (!)
|
||||||
)
|
target_size=(target_size[1], target_size[0]),
|
||||||
|
)
|
||||||
|
|
||||||
# custom normalization
|
# custom normalization
|
||||||
img = preprocessing.normalize_input(img=img, normalization=normalization)
|
img = preprocessing.normalize_input(img=img, normalization=normalization)
|
||||||
|
@ -45,7 +45,7 @@ def analysis(
|
|||||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
|
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
@ -192,7 +192,7 @@ def search_identity(
|
|||||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
'euclidean', 'euclidean_l2' (default is cosine).
|
'euclidean', 'euclidean_l2' (default is cosine).
|
||||||
@ -374,7 +374,7 @@ def grab_facial_areas(
|
|||||||
Args:
|
Args:
|
||||||
img (np.ndarray): image itself
|
img (np.ndarray): image itself
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
threshold (int): threshold for facial area, discard smaller ones
|
threshold (int): threshold for facial area, discard smaller ones
|
||||||
Returns
|
Returns
|
||||||
@ -443,7 +443,7 @@ def perform_facial_recognition(
|
|||||||
db_path (string): Path to the folder containing image files. All detected faces
|
db_path (string): Path to the folder containing image files. All detected faces
|
||||||
in the database will be considered in the decision-making process.
|
in the database will be considered in the decision-making process.
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv).
|
(default is opencv).
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
'euclidean', 'euclidean_l2' (default is cosine).
|
'euclidean', 'euclidean_l2' (default is cosine).
|
||||||
|
@ -47,7 +47,7 @@ def verify(
|
|||||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||||
|
|
||||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'
|
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'
|
||||||
(default is opencv)
|
(default is opencv)
|
||||||
|
|
||||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||||
|
@ -22,6 +22,10 @@ model_names = [
|
|||||||
"ArcFace",
|
"ArcFace",
|
||||||
"SFace",
|
"SFace",
|
||||||
"GhostFaceNet",
|
"GhostFaceNet",
|
||||||
|
"yolov8",
|
||||||
|
"yolov11n",
|
||||||
|
"yolov11s",
|
||||||
|
"yolov11m"
|
||||||
]
|
]
|
||||||
|
|
||||||
detector_backends = [
|
detector_backends = [
|
||||||
@ -35,6 +39,7 @@ detector_backends = [
|
|||||||
"yunet",
|
"yunet",
|
||||||
"yolov8",
|
"yolov8",
|
||||||
"yolov11n",
|
"yolov11n",
|
||||||
|
"yolov11s",
|
||||||
"yolov11m",
|
"yolov11m",
|
||||||
"centerface",
|
"centerface",
|
||||||
]
|
]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user