diff --git a/README.md b/README.md index a0cd133..106bdc0 100644 --- a/README.md +++ b/README.md @@ -194,7 +194,7 @@ Age model got ± 4.65 MAE; gender model got 97.44% accuracy, 96.29% precision an **Face Detectors** - [`Demo`](https://youtu.be/GZ2p2hj2H5k) -Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`SSD`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MTCNN`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/) and [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/) detectors are wrapped in deepface. +Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`SSD`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MTCNN`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/), [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/) and [`YOLOv8 Face`](https://github.com/derronqi/yolov8-face) detectors are wrapped in deepface.

@@ -207,7 +207,8 @@ backends = [ 'dlib', 'mtcnn', 'retinaface', - 'mediapipe' + 'mediapipe', + 'yolov8', ] #face verification diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index d38a810..38aa545 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -116,7 +116,7 @@ def verify( This might be convenient for low resolution images. detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd, - dlib or mediapipe + dlib, mediapipe or yolov8. align (boolean): alignment according to the eye positions. @@ -251,7 +251,7 @@ def analyze( resolution images. detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd, - dlib or mediapipe. + dlib, mediapipe or yolov8. align (boolean): alignment according to the eye positions. @@ -429,7 +429,7 @@ def find( resolution images. detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd, - dlib or mediapipe + dlib, mediapipe or yolov8. align (boolean): alignment according to the eye positions. @@ -456,6 +456,7 @@ def find( file_name = file_name.replace("-", "_").lower() if path.exists(db_path + "/" + file_name): + if not silent: print( f"WARNING: Representations for images in {db_path} folder were previously stored" @@ -640,7 +641,7 @@ def represent( This might be convenient for low resolution images. detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd, - dlib or mediapipe + dlib, mediapipe or yolov8. align (boolean): alignment according to the eye positions. @@ -725,7 +726,7 @@ def stream( model_name (string): VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace - detector_backend (string): opencv, retinaface, mtcnn, ssd, dlib or mediapipe + detector_backend (string): opencv, retinaface, mtcnn, ssd, dlib, mediapipe or yolov8. distance_metric (string): cosine, euclidean, euclidean_l2 diff --git a/deepface/detectors/FaceDetector.py b/deepface/detectors/FaceDetector.py index 367f588..6e7f258 100644 --- a/deepface/detectors/FaceDetector.py +++ b/deepface/detectors/FaceDetector.py @@ -9,11 +9,11 @@ from deepface.detectors import ( MtcnnWrapper, RetinaFaceWrapper, MediapipeWrapper, + YoloWrapper, ) def build_model(detector_backend): - global face_detector_obj # singleton design pattern backends = { @@ -23,6 +23,7 @@ def build_model(detector_backend): "mtcnn": MtcnnWrapper.build_model, "retinaface": RetinaFaceWrapper.build_model, "mediapipe": MediapipeWrapper.build_model, + "yolov8": YoloWrapper.build_model, } if not "face_detector_obj" in globals(): @@ -42,20 +43,22 @@ def build_model(detector_backend): def detect_face(face_detector, detector_backend, img, align=True): - obj = detect_faces(face_detector, detector_backend, img, align) if len(obj) > 0: face, region, confidence = obj[0] # discard multiple faces + + # If no face is detected, set face to None, + # image region to full image, and confidence to 0. else: # len(obj) == 0 face = None region = [0, 0, img.shape[1], img.shape[0]] + confidence = 0 return face, region, confidence def detect_faces(face_detector, detector_backend, img, align=True): - backends = { "opencv": OpenCvWrapper.detect_face, "ssd": SsdWrapper.detect_face, @@ -63,6 +66,7 @@ def detect_faces(face_detector, detector_backend, img, align=True): "mtcnn": MtcnnWrapper.detect_face, "retinaface": RetinaFaceWrapper.detect_face, "mediapipe": MediapipeWrapper.detect_face, + "yolov8": YoloWrapper.detect_face, } detect_face_fn = backends.get(detector_backend) @@ -76,7 +80,6 @@ def detect_faces(face_detector, detector_backend, img, align=True): def alignment_procedure(img, left_eye, right_eye): - # this function aligns given face in img based on left and right eye coordinates left_eye_x, left_eye_y = left_eye @@ -104,7 +107,6 @@ def alignment_procedure(img, left_eye, right_eye): # apply cosine rule if b != 0 and c != 0: # this multiplication causes division by zero in cos_a calculation - cos_a = (b * b + c * c - a * a) / (2 * b * c) angle = np.arccos(cos_a) # angle in radian angle = (angle * 180) / math.pi # radian to degree diff --git a/deepface/detectors/YoloWrapper.py b/deepface/detectors/YoloWrapper.py new file mode 100644 index 0000000..fd38a9b --- /dev/null +++ b/deepface/detectors/YoloWrapper.py @@ -0,0 +1,62 @@ +from deepface.detectors import FaceDetector + +# Model's weights paths +PATH = "/.deepface/weights/yolov8n-face.pt" + +# Google Drive URL +WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb" + +# Confidence thresholds for landmarks detection +# used in alignment_procedure function +LANDMARKS_CONFIDENCE_THRESHOLD = 0.5 + + +def build_model(): + """Build YOLO (yolov8n-face) model""" + import gdown + import os + + # Import the Ultralytics YOLO model + from ultralytics import YOLO + + from deepface.commons.functions import get_deepface_home + weight_path = f"{get_deepface_home()}{PATH}" + + # Download the model's weights if they don't exist + if not os.path.isfile(weight_path): + gdown.download(WEIGHT_URL, weight_path, quiet=False) + print(f"Downloaded YOLO model {os.path.basename(weight_path)}") + + # Return face_detector + return YOLO(weight_path) + + +def detect_face(face_detector, img, align=False): + resp = [] + + # Detect faces + results = face_detector.predict( + img, verbose=False, show=False, conf=0.25)[0] + + # For each face, extract the bounding box, the landmarks and confidence + for result in results: + # Extract the bounding box and the confidence + x, y, w, h = result.boxes.xywh.tolist()[0] + confidence = result.boxes.conf.tolist()[0] + + x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h) + detected_face = img[y: y + h, x: x + w].copy() + + if align: + # Extract landmarks + left_eye, right_eye, _, _, _ = result.keypoints.tolist() + # Check the landmarks confidence before alignment + if (left_eye[2] > LANDMARKS_CONFIDENCE_THRESHOLD and + right_eye[2] > LANDMARKS_CONFIDENCE_THRESHOLD): + detected_face = FaceDetector.alignment_procedure( + detected_face, left_eye[:2], right_eye[:2] + ) + + resp.append((detected_face, [x, y, w, h], confidence)) + + return resp diff --git a/requirements_additional.txt b/requirements_additional.txt index ee752b9..cae0790 100644 --- a/requirements_additional.txt +++ b/requirements_additional.txt @@ -1,3 +1,4 @@ opencv-contrib-python>=4.3.0.36 mediapipe>=0.8.7.3 -dlib>=19.20.0 \ No newline at end of file +dlib>=19.20.0 +ultralytics @ git+https://github.com/derronqi/yolov8-face.git@b623989575bdb78601b5ca717851e3d63ca9e01c \ No newline at end of file