mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
Merge pull request #758 from Vincent-Stragier/YOLOv8
Integration of YOLOv8-face close #732
This commit is contained in:
commit
0b8e5ca472
@ -194,7 +194,7 @@ Age model got ± 4.65 MAE; gender model got 97.44% accuracy, 96.29% precision an
|
|||||||
|
|
||||||
**Face Detectors** - [`Demo`](https://youtu.be/GZ2p2hj2H5k)
|
**Face Detectors** - [`Demo`](https://youtu.be/GZ2p2hj2H5k)
|
||||||
|
|
||||||
Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`SSD`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MTCNN`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/) and [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/) detectors are wrapped in deepface.
|
Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`SSD`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MTCNN`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/), [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/) and [`YOLOv8 Face`](https://github.com/derronqi/yolov8-face) detectors are wrapped in deepface.
|
||||||
|
|
||||||
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-portfolio-v3.jpg" width="95%" height="95%"></p>
|
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-portfolio-v3.jpg" width="95%" height="95%"></p>
|
||||||
|
|
||||||
@ -207,7 +207,8 @@ backends = [
|
|||||||
'dlib',
|
'dlib',
|
||||||
'mtcnn',
|
'mtcnn',
|
||||||
'retinaface',
|
'retinaface',
|
||||||
'mediapipe'
|
'mediapipe',
|
||||||
|
'yolov8',
|
||||||
]
|
]
|
||||||
|
|
||||||
#face verification
|
#face verification
|
||||||
|
@ -116,7 +116,7 @@ def verify(
|
|||||||
This might be convenient for low resolution images.
|
This might be convenient for low resolution images.
|
||||||
|
|
||||||
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
||||||
dlib or mediapipe
|
dlib, mediapipe or yolov8.
|
||||||
|
|
||||||
align (boolean): alignment according to the eye positions.
|
align (boolean): alignment according to the eye positions.
|
||||||
|
|
||||||
@ -251,7 +251,7 @@ def analyze(
|
|||||||
resolution images.
|
resolution images.
|
||||||
|
|
||||||
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
||||||
dlib or mediapipe.
|
dlib, mediapipe or yolov8.
|
||||||
|
|
||||||
align (boolean): alignment according to the eye positions.
|
align (boolean): alignment according to the eye positions.
|
||||||
|
|
||||||
@ -429,7 +429,7 @@ def find(
|
|||||||
resolution images.
|
resolution images.
|
||||||
|
|
||||||
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
||||||
dlib or mediapipe
|
dlib, mediapipe or yolov8.
|
||||||
|
|
||||||
align (boolean): alignment according to the eye positions.
|
align (boolean): alignment according to the eye positions.
|
||||||
|
|
||||||
@ -456,6 +456,7 @@ def find(
|
|||||||
file_name = file_name.replace("-", "_").lower()
|
file_name = file_name.replace("-", "_").lower()
|
||||||
|
|
||||||
if path.exists(db_path + "/" + file_name):
|
if path.exists(db_path + "/" + file_name):
|
||||||
|
|
||||||
if not silent:
|
if not silent:
|
||||||
print(
|
print(
|
||||||
f"WARNING: Representations for images in {db_path} folder were previously stored"
|
f"WARNING: Representations for images in {db_path} folder were previously stored"
|
||||||
@ -640,7 +641,7 @@ def represent(
|
|||||||
This might be convenient for low resolution images.
|
This might be convenient for low resolution images.
|
||||||
|
|
||||||
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
|
||||||
dlib or mediapipe
|
dlib, mediapipe or yolov8.
|
||||||
|
|
||||||
align (boolean): alignment according to the eye positions.
|
align (boolean): alignment according to the eye positions.
|
||||||
|
|
||||||
@ -725,7 +726,7 @@ def stream(
|
|||||||
model_name (string): VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
model_name (string): VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||||
ArcFace, SFace
|
ArcFace, SFace
|
||||||
|
|
||||||
detector_backend (string): opencv, retinaface, mtcnn, ssd, dlib or mediapipe
|
detector_backend (string): opencv, retinaface, mtcnn, ssd, dlib, mediapipe or yolov8.
|
||||||
|
|
||||||
distance_metric (string): cosine, euclidean, euclidean_l2
|
distance_metric (string): cosine, euclidean, euclidean_l2
|
||||||
|
|
||||||
|
@ -9,11 +9,11 @@ from deepface.detectors import (
|
|||||||
MtcnnWrapper,
|
MtcnnWrapper,
|
||||||
RetinaFaceWrapper,
|
RetinaFaceWrapper,
|
||||||
MediapipeWrapper,
|
MediapipeWrapper,
|
||||||
|
YoloWrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_model(detector_backend):
|
def build_model(detector_backend):
|
||||||
|
|
||||||
global face_detector_obj # singleton design pattern
|
global face_detector_obj # singleton design pattern
|
||||||
|
|
||||||
backends = {
|
backends = {
|
||||||
@ -23,6 +23,7 @@ def build_model(detector_backend):
|
|||||||
"mtcnn": MtcnnWrapper.build_model,
|
"mtcnn": MtcnnWrapper.build_model,
|
||||||
"retinaface": RetinaFaceWrapper.build_model,
|
"retinaface": RetinaFaceWrapper.build_model,
|
||||||
"mediapipe": MediapipeWrapper.build_model,
|
"mediapipe": MediapipeWrapper.build_model,
|
||||||
|
"yolov8": YoloWrapper.build_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
if not "face_detector_obj" in globals():
|
if not "face_detector_obj" in globals():
|
||||||
@ -42,20 +43,22 @@ def build_model(detector_backend):
|
|||||||
|
|
||||||
|
|
||||||
def detect_face(face_detector, detector_backend, img, align=True):
|
def detect_face(face_detector, detector_backend, img, align=True):
|
||||||
|
|
||||||
obj = detect_faces(face_detector, detector_backend, img, align)
|
obj = detect_faces(face_detector, detector_backend, img, align)
|
||||||
|
|
||||||
if len(obj) > 0:
|
if len(obj) > 0:
|
||||||
face, region, confidence = obj[0] # discard multiple faces
|
face, region, confidence = obj[0] # discard multiple faces
|
||||||
|
|
||||||
|
# If no face is detected, set face to None,
|
||||||
|
# image region to full image, and confidence to 0.
|
||||||
else: # len(obj) == 0
|
else: # len(obj) == 0
|
||||||
face = None
|
face = None
|
||||||
region = [0, 0, img.shape[1], img.shape[0]]
|
region = [0, 0, img.shape[1], img.shape[0]]
|
||||||
|
confidence = 0
|
||||||
|
|
||||||
return face, region, confidence
|
return face, region, confidence
|
||||||
|
|
||||||
|
|
||||||
def detect_faces(face_detector, detector_backend, img, align=True):
|
def detect_faces(face_detector, detector_backend, img, align=True):
|
||||||
|
|
||||||
backends = {
|
backends = {
|
||||||
"opencv": OpenCvWrapper.detect_face,
|
"opencv": OpenCvWrapper.detect_face,
|
||||||
"ssd": SsdWrapper.detect_face,
|
"ssd": SsdWrapper.detect_face,
|
||||||
@ -63,6 +66,7 @@ def detect_faces(face_detector, detector_backend, img, align=True):
|
|||||||
"mtcnn": MtcnnWrapper.detect_face,
|
"mtcnn": MtcnnWrapper.detect_face,
|
||||||
"retinaface": RetinaFaceWrapper.detect_face,
|
"retinaface": RetinaFaceWrapper.detect_face,
|
||||||
"mediapipe": MediapipeWrapper.detect_face,
|
"mediapipe": MediapipeWrapper.detect_face,
|
||||||
|
"yolov8": YoloWrapper.detect_face,
|
||||||
}
|
}
|
||||||
|
|
||||||
detect_face_fn = backends.get(detector_backend)
|
detect_face_fn = backends.get(detector_backend)
|
||||||
@ -76,7 +80,6 @@ def detect_faces(face_detector, detector_backend, img, align=True):
|
|||||||
|
|
||||||
|
|
||||||
def alignment_procedure(img, left_eye, right_eye):
|
def alignment_procedure(img, left_eye, right_eye):
|
||||||
|
|
||||||
# this function aligns given face in img based on left and right eye coordinates
|
# this function aligns given face in img based on left and right eye coordinates
|
||||||
|
|
||||||
left_eye_x, left_eye_y = left_eye
|
left_eye_x, left_eye_y = left_eye
|
||||||
@ -104,7 +107,6 @@ def alignment_procedure(img, left_eye, right_eye):
|
|||||||
# apply cosine rule
|
# apply cosine rule
|
||||||
|
|
||||||
if b != 0 and c != 0: # this multiplication causes division by zero in cos_a calculation
|
if b != 0 and c != 0: # this multiplication causes division by zero in cos_a calculation
|
||||||
|
|
||||||
cos_a = (b * b + c * c - a * a) / (2 * b * c)
|
cos_a = (b * b + c * c - a * a) / (2 * b * c)
|
||||||
angle = np.arccos(cos_a) # angle in radian
|
angle = np.arccos(cos_a) # angle in radian
|
||||||
angle = (angle * 180) / math.pi # radian to degree
|
angle = (angle * 180) / math.pi # radian to degree
|
||||||
|
62
deepface/detectors/YoloWrapper.py
Normal file
62
deepface/detectors/YoloWrapper.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from deepface.detectors import FaceDetector
|
||||||
|
|
||||||
|
# Model's weights paths
|
||||||
|
PATH = "/.deepface/weights/yolov8n-face.pt"
|
||||||
|
|
||||||
|
# Google Drive URL
|
||||||
|
WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
|
||||||
|
|
||||||
|
# Confidence thresholds for landmarks detection
|
||||||
|
# used in alignment_procedure function
|
||||||
|
LANDMARKS_CONFIDENCE_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def build_model():
|
||||||
|
"""Build YOLO (yolov8n-face) model"""
|
||||||
|
import gdown
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Import the Ultralytics YOLO model
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
from deepface.commons.functions import get_deepface_home
|
||||||
|
weight_path = f"{get_deepface_home()}{PATH}"
|
||||||
|
|
||||||
|
# Download the model's weights if they don't exist
|
||||||
|
if not os.path.isfile(weight_path):
|
||||||
|
gdown.download(WEIGHT_URL, weight_path, quiet=False)
|
||||||
|
print(f"Downloaded YOLO model {os.path.basename(weight_path)}")
|
||||||
|
|
||||||
|
# Return face_detector
|
||||||
|
return YOLO(weight_path)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_face(face_detector, img, align=False):
|
||||||
|
resp = []
|
||||||
|
|
||||||
|
# Detect faces
|
||||||
|
results = face_detector.predict(
|
||||||
|
img, verbose=False, show=False, conf=0.25)[0]
|
||||||
|
|
||||||
|
# For each face, extract the bounding box, the landmarks and confidence
|
||||||
|
for result in results:
|
||||||
|
# Extract the bounding box and the confidence
|
||||||
|
x, y, w, h = result.boxes.xywh.tolist()[0]
|
||||||
|
confidence = result.boxes.conf.tolist()[0]
|
||||||
|
|
||||||
|
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
|
||||||
|
detected_face = img[y: y + h, x: x + w].copy()
|
||||||
|
|
||||||
|
if align:
|
||||||
|
# Extract landmarks
|
||||||
|
left_eye, right_eye, _, _, _ = result.keypoints.tolist()
|
||||||
|
# Check the landmarks confidence before alignment
|
||||||
|
if (left_eye[2] > LANDMARKS_CONFIDENCE_THRESHOLD and
|
||||||
|
right_eye[2] > LANDMARKS_CONFIDENCE_THRESHOLD):
|
||||||
|
detected_face = FaceDetector.alignment_procedure(
|
||||||
|
detected_face, left_eye[:2], right_eye[:2]
|
||||||
|
)
|
||||||
|
|
||||||
|
resp.append((detected_face, [x, y, w, h], confidence))
|
||||||
|
|
||||||
|
return resp
|
@ -1,3 +1,4 @@
|
|||||||
opencv-contrib-python>=4.3.0.36
|
opencv-contrib-python>=4.3.0.36
|
||||||
mediapipe>=0.8.7.3
|
mediapipe>=0.8.7.3
|
||||||
dlib>=19.20.0
|
dlib>=19.20.0
|
||||||
|
ultralytics @ git+https://github.com/derronqi/yolov8-face.git@b623989575bdb78601b5ca717851e3d63ca9e01c
|
Loading…
x
Reference in New Issue
Block a user