Merge pull request #1184 from serengil/feat-task-1204-centerface

Feat task 1204 centerface
This commit is contained in:
Sefik Ilkin Serengil 2024-04-12 17:09:42 +01:00 committed by GitHub
commit b2ed849183
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 268 additions and 30 deletions

View File

@ -194,9 +194,9 @@ Age model got ± 4.65 MAE; gender model got 97.44% accuracy, 96.29% precision an
**Face Detectors** - [`Demo`](https://youtu.be/GZ2p2hj2H5k) **Face Detectors** - [`Demo`](https://youtu.be/GZ2p2hj2H5k)
Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`SSD`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MTCNN`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), [`Faster MTCNN`](https://github.com/timesler/facenet-pytorch), [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/), [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/), [`YOLOv8 Face`](https://github.com/derronqi/yolov8-face) and [`YuNet`](https://github.com/ShiqiYu/libfacedetection) detectors are wrapped in deepface. Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`Ssd`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MtCnn`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), `Faster MTCNN`, [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/), [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/), `Yolo`, `YuNet` and `CenterFace` detectors are wrapped in deepface.
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-portfolio-v5.jpg" width="95%" height="95%"></p> <p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-portfolio-v6.jpg" width="95%" height="95%"></p>
All deepface functions accept an optional detector backend input argument. You can switch among those detectors with this argument. OpenCV is the default detector. All deepface functions accept an optional detector backend input argument. You can switch among those detectors with this argument. OpenCV is the default detector.
@ -206,11 +206,12 @@ backends = [
'ssd', 'ssd',
'dlib', 'dlib',
'mtcnn', 'mtcnn',
'fastmtcnn',
'retinaface', 'retinaface',
'mediapipe', 'mediapipe',
'yolov8', 'yolov8',
'yunet', 'yunet',
'fastmtcnn', 'centerface',
] ]
#face verification #face verification
@ -317,9 +318,7 @@ You can also run these commands if you are running deepface with docker. Please
## FAQ and Troubleshooting ## FAQ and Troubleshooting
If you believe you have identified a bug or encountered a limitation in DeepFace that is not covered in the [existing issues](https://github.com/serengil/deepface/issues) or [closed issues](https://github.com/serengil/deepface/issues?q=is%3Aissue+is%3Aclosed), kindly open a new issue. Ensure that your submission includes clear and detailed reproduction steps, such as your Python version, your DeepFace version (provided by `DeepFace.__version__`), versions of dependent packages (provided by pip freeze), specifics of any exception messages, details about how you are calling DeepFace, and the input image(s) you are using. If you believe you have identified a bug or encountered a limitation in DeepFace that is not covered in the [existing issues](https://github.com/serengil/deepface/issues) or [closed issues](https://github.com/serengil/deepface/issues?q=is%3Aissue+is%3Aclosed), kindly open a new issue. Ensure that your submission includes clear and detailed reproduction steps, such as your Python version, your DeepFace version, versions of dependent packages (provided by pip freeze), specifics of any exception messages, details about how you are calling DeepFace, and the input image(s) you are using.
Additionally, it is possible to encounter issues due to recently released dependencies, primarily Python itself or TensorFlow. It is recommended to synchronize your dependencies with the versions [specified in my environment](https://github.com/serengil/deepface/blob/master/requirements_local) and [same python version](https://github.com/serengil/deepface/blob/master/Dockerfile#L2) not to have potential compatibility issues.
## Contribution ## Contribution
@ -351,7 +350,7 @@ If you use deepface in your research for facial recogntion purposes, please cite
pages = {23-27}, pages = {23-27},
year = {2020}, year = {2020},
doi = {10.1109/ASYU50717.2020.9259802}, doi = {10.1109/ASYU50717.2020.9259802},
url = {https://doi.org/10.1109/ASYU50717.2020.9259802}, url = {https://ieeexplore.ieee.org/document/9259802},
organization = {IEEE} organization = {IEEE}
} }
``` ```
@ -366,7 +365,7 @@ If you use deepface in your research for facial attribute analysis purposes such
pages = {1-4}, pages = {1-4},
year = {2021}, year = {2021},
doi = {10.1109/ICEET53442.2021.9659697}, doi = {10.1109/ICEET53442.2021.9659697},
url = {https://doi.org/10.1109/ICEET53442.2021.9659697}, url = {https://ieeexplore.ieee.org/document/9659697},
organization = {IEEE} organization = {IEEE}
} }
``` ```
@ -377,6 +376,10 @@ Also, if you use deepface in your GitHub projects, please add `deepface` in the
DeepFace is licensed under the MIT License - see [`LICENSE`](https://github.com/serengil/deepface/blob/master/LICENSE) for more details. DeepFace is licensed under the MIT License - see [`LICENSE`](https://github.com/serengil/deepface/blob/master/LICENSE) for more details.
DeepFace wraps some external face recognition models: [VGG-Face](http://www.robots.ox.ac.uk/~vgg/software/vgg_face/), [Facenet](https://github.com/davidsandberg/facenet/blob/master/LICENSE.md), [OpenFace](https://github.com/iwantooxxoox/Keras-OpenFace/blob/master/LICENSE), [DeepFace](https://github.com/swghosh/DeepFace), [DeepID](https://github.com/Ruoyiran/DeepID/blob/master/LICENSE.md), [ArcFace](https://github.com/leondgarse/Keras_insightface/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt), [SFace](https://github.com/opencv/opencv_zoo/blob/master/models/face_recognition_sface/LICENSE) and [GhostFaceNet](https://github.com/HamadYA/GhostFaceNets/blob/main/LICENSE). Besides, age, gender and race / ethnicity models were trained on the backbone of VGG-Face with transfer learning. Licence types will be inherited if you are going to use those models. Please check the license types of those models for production purposes. DeepFace wraps some external face recognition models: [VGG-Face](http://www.robots.ox.ac.uk/~vgg/software/vgg_face/), [Facenet](https://github.com/davidsandberg/facenet/blob/master/LICENSE.md), [OpenFace](https://github.com/iwantooxxoox/Keras-OpenFace/blob/master/LICENSE), [DeepFace](https://github.com/swghosh/DeepFace), [DeepID](https://github.com/Ruoyiran/DeepID/blob/master/LICENSE.md), [ArcFace](https://github.com/leondgarse/Keras_insightface/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt), [SFace](https://github.com/opencv/opencv_zoo/blob/master/models/face_recognition_sface/LICENSE) and [GhostFaceNet](https://github.com/HamadYA/GhostFaceNets/blob/main/LICENSE). Besides, age, gender and race / ethnicity models were trained on the backbone of VGG-Face with transfer learning.
Similarly, DeepFace wraps many face detectors: [OpenCv](https://github.com/opencv/opencv/blob/4.x/LICENSE), [Ssd](https://github.com/opencv/opencv/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/LICENSE.txt), [MtCnn](https://github.com/ipazc/mtcnn/blob/master/LICENSE), [Fast MtCnn](https://github.com/timesler/facenet-pytorch/blob/master/LICENSE.md), [RetinaFace](https://github.com/serengil/retinaface/blob/master/LICENSE), [MediaPipe](https://github.com/google/mediapipe/blob/master/LICENSE), [YuNet](https://github.com/ShiqiYu/libfacedetection/blob/master/LICENSE), [Yolo](https://github.com/derronqi/yolov8-face/blob/main/LICENSE) and [CenterFace](https://github.com/Star-Clouds/CenterFace/blob/master/LICENSE).
Licence types will be inherited if you are going to use those models. Please check the license types of those models for production purposes.
DeepFace [logo](https://thenounproject.com/term/face-recognition/2965879/) is created by [Adrien Coquet](https://thenounproject.com/coquet_adrien/) and it is licensed under [Creative Commons: By Attribution 3.0 License](https://creativecommons.org/licenses/by/3.0/). DeepFace [logo](https://thenounproject.com/term/face-recognition/2965879/) is created by [Adrien Coquet](https://thenounproject.com/coquet_adrien/) and it is licensed under [Creative Commons: By Attribution 3.0 License](https://creativecommons.org/licenses/by/3.0/).

View File

@ -88,7 +88,8 @@ def verify(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
@ -168,7 +169,8 @@ def analyze(
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
@ -272,7 +274,8 @@ def find(
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).
@ -348,7 +351,8 @@ def represent(
(default is True). (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).
@ -406,7 +410,8 @@ def stream(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
@ -454,7 +459,8 @@ def extract_faces(
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), or base64 encoded images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
@ -520,7 +526,8 @@ def detectFace(
added to resize the image (default is (224, 224)). added to resize the image (default is (224, 224)).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).

View File

@ -0,0 +1,217 @@
# built-in dependencies
import os
from typing import List
# 3rd party dependencies
import numpy as np
import cv2
import gdown
# project dependencies
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
# pylint: disable=c-extension-no-member
WEIGHTS_URL = "https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx"
class CenterFaceClient(Detector):
def __init__(self):
# BUG: model must be flushed for each call
# self.model = self.build_model()
pass
def build_model(self):
"""
Download pre-trained weights of CenterFace model if necessary and load built model
"""
weights_path = f"{folder_utils.get_deepface_home()}/.deepface/weights/centerface.onnx"
if not os.path.isfile(weights_path):
logger.info(f"Downloading CenterFace weights from {WEIGHTS_URL} to {weights_path}...")
try:
gdown.download(WEIGHTS_URL, weights_path, quiet=False)
except Exception as err:
raise ValueError(
f"Exception while downloading CenterFace weights from {WEIGHTS_URL}."
f"You may consider to download it to {weights_path} manually."
) from err
logger.info(f"CenterFace model is just downloaded to {os.path.basename(weights_path)}")
return CenterFace(weight_path=weights_path)
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
"""
Detect and align face with CenterFace
Args:
img (np.ndarray): pre-loaded image as numpy array
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
threshold = float(os.getenv("CENTERFACE_THRESHOLD", "0.80"))
# BUG: model causes problematic results from 2nd call if it is not flushed
# detections, landmarks = self.model.forward(
# img, img.shape[0], img.shape[1], threshold=threshold
# )
detections, landmarks = self.build_model().forward(
img, img.shape[0], img.shape[1], threshold=threshold
)
for i, detection in enumerate(detections):
boxes, confidence = detection[:4], detection[4]
x = boxes[0]
y = boxes[1]
w = boxes[2] - x
h = boxes[3] - y
landmark = landmarks[i]
right_eye = (int(landmark[0]), int(landmark[1]))
left_eye = (int(landmark[2]), int(landmark[3]))
# nose = (int(landmark[4]), int(landmark [5]))
# mouth_right = (int(landmark[6]), int(landmark [7]))
# mouth_left = (int(landmark[8]), int(landmark [9]))
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=min(max(0, float(confidence)), 1.0),
)
resp.append(facial_area)
return resp
class CenterFace:
"""
This class is heavily inspired from
github.com/Star-Clouds/CenterFace/blob/master/prj-python/centerface.py
"""
def __init__(self, weight_path: str):
self.net = cv2.dnn.readNetFromONNX(weight_path)
self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = 0, 0, 0, 0
def forward(self, img, height, width, threshold=0.5):
self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = self.transform(height, width)
return self.inference_opencv(img, threshold)
def inference_opencv(self, img, threshold):
blob = cv2.dnn.blobFromImage(
img,
scalefactor=1.0,
size=(self.img_w_new, self.img_h_new),
mean=(0, 0, 0),
swapRB=True,
crop=False,
)
self.net.setInput(blob)
heatmap, scale, offset, lms = self.net.forward(["537", "538", "539", "540"])
return self.postprocess(heatmap, lms, offset, scale, threshold)
def transform(self, h, w):
img_h_new, img_w_new = int(np.ceil(h / 32) * 32), int(np.ceil(w / 32) * 32)
scale_h, scale_w = img_h_new / h, img_w_new / w
return img_h_new, img_w_new, scale_h, scale_w
def postprocess(self, heatmap, lms, offset, scale, threshold):
dets, lms = self.decode(
heatmap, scale, offset, lms, (self.img_h_new, self.img_w_new), threshold=threshold
)
if len(dets) > 0:
dets[:, 0:4:2], dets[:, 1:4:2] = (
dets[:, 0:4:2] / self.scale_w,
dets[:, 1:4:2] / self.scale_h,
)
lms[:, 0:10:2], lms[:, 1:10:2] = (
lms[:, 0:10:2] / self.scale_w,
lms[:, 1:10:2] / self.scale_h,
)
else:
dets = np.empty(shape=[0, 5], dtype=np.float32)
lms = np.empty(shape=[0, 10], dtype=np.float32)
return dets, lms
def decode(self, heatmap, scale, offset, landmark, size, threshold=0.1):
heatmap = np.squeeze(heatmap)
scale0, scale1 = scale[0, 0, :, :], scale[0, 1, :, :]
offset0, offset1 = offset[0, 0, :, :], offset[0, 1, :, :]
c0, c1 = np.where(heatmap > threshold)
boxes, lms = [], []
if len(c0) > 0:
# pylint:disable=consider-using-enumerate
for i in range(len(c0)):
s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4
o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]]
s = heatmap[c0[i], c1[i]]
x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(
0, (c0[i] + o0 + 0.5) * 4 - s0 / 2
)
x1, y1 = min(x1, size[1]), min(y1, size[0])
boxes.append([x1, y1, min(x1 + s1, size[1]), min(y1 + s0, size[0]), s])
lm = []
for j in range(5):
lm.append(landmark[0, j * 2 + 1, c0[i], c1[i]] * s1 + x1)
lm.append(landmark[0, j * 2, c0[i], c1[i]] * s0 + y1)
lms.append(lm)
boxes = np.asarray(boxes, dtype=np.float32)
keep = self.nms(boxes[:, :4], boxes[:, 4], 0.3)
boxes = boxes[keep, :]
lms = np.asarray(lms, dtype=np.float32)
lms = lms[keep, :]
return boxes, lms
def nms(self, boxes, scores, nms_thresh):
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = np.argsort(scores)[::-1]
num_detections = boxes.shape[0]
suppressed = np.zeros((num_detections,), dtype=bool)
keep = []
for _i in range(num_detections):
i = order[_i]
if suppressed[i]:
continue
keep.append(i)
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, num_detections):
j = order[_j]
if suppressed[j]:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0, xx2 - xx1 + 1)
h = max(0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= nms_thresh:
suppressed[j] = True
return keep

View File

@ -12,6 +12,7 @@ from deepface.detectors import (
Ssd, Ssd,
Yolo, Yolo,
YuNet, YuNet,
CenterFace,
) )
from deepface.commons import logger as log from deepface.commons import logger as log
@ -38,6 +39,7 @@ def build_model(detector_backend: str) -> Any:
"yolov8": Yolo.YoloClient, "yolov8": Yolo.YoloClient,
"yunet": YuNet.YuNetClient, "yunet": YuNet.YuNetClient,
"fastmtcnn": FastMtCnn.FastMtCnnClient, "fastmtcnn": FastMtCnn.FastMtCnnClient,
"centerface": CenterFace.CenterFaceClient,
} }
if not "face_detector_obj" in globals(): if not "face_detector_obj" in globals():
@ -93,7 +95,7 @@ def detect_faces(
expand_percentage = 0 expand_percentage = 0
# find facial areas of given image # find facial areas of given image
facial_areas = face_detector.detect_faces(img=img) facial_areas = face_detector.detect_faces(img)
results = [] results = []
for facial_area in facial_areas: for facial_area in facial_areas:

View File

@ -34,7 +34,8 @@ def analyze(
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).

View File

@ -33,7 +33,8 @@ def extract_faces(
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), or base64 encoded images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv) 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv)
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images. Default is True. Set to False to avoid the exception for low-resolution images.

View File

@ -52,7 +52,7 @@ def find(
Default is True. Set to False to avoid the exception for low-resolution images. Default is True. Set to False to avoid the exception for low-resolution images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip'. 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
align (boolean): Perform alignment based on the eye positions. align (boolean): Perform alignment based on the eye positions.

View File

@ -33,7 +33,7 @@ def represent(
Default is True. Set to False to avoid the exception for low-resolution images. Default is True. Set to False to avoid the exception for low-resolution images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip'. 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
align (boolean): Perform alignment based on the eye positions. align (boolean): Perform alignment based on the eye positions.

View File

@ -43,7 +43,8 @@ def analysis(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
@ -182,7 +183,8 @@ def search_identity(
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
Returns: Returns:
@ -349,7 +351,8 @@ def grab_facial_areas(
Args: Args:
img (np.ndarray): image itself img (np.ndarray): image itself
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
threshold (int): threshold for facial area, discard smaller ones threshold (int): threshold for facial area, discard smaller ones
Returns Returns
result (list): list of tuple with x, y, w and h coordinates result (list): list of tuple with x, y, w and h coordinates
@ -414,7 +417,8 @@ def perform_facial_recognition(
db_path (string): Path to the folder containing image files. All detected faces db_path (string): Path to the folder containing image files. All detected faces
in the database will be considered in the decision-making process. in the database will be considered in the decision-making process.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,

View File

@ -45,7 +45,8 @@ def verify(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' or 'skip' (default is opencv) 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv)
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).

Binary file not shown.

After

Width:  |  Height:  |  Size: 228 KiB

View File

@ -34,9 +34,9 @@ detector_backends = [
"retinaface", "retinaface",
"yunet", "yunet",
"yolov8", "yolov8",
"centerface",
] ]
# verification # verification
for model_name in model_names: for model_name in model_names:
obj = DeepFace.verify( obj = DeepFace.verify(
@ -60,7 +60,6 @@ dfs = DeepFace.find(
for df in dfs: for df in dfs:
logger.info(df) logger.info(df)
expand_areas = [0] expand_areas = [0]
img_paths = ["dataset/img11.jpg", "dataset/img11_reflection.jpg"] img_paths = ["dataset/img11.jpg", "dataset/img11_reflection.jpg"]
for expand_area in expand_areas: for expand_area in expand_areas:
@ -75,7 +74,7 @@ for expand_area in expand_areas:
) )
for face_obj in face_objs: for face_obj in face_objs:
face = face_obj["face"] face = face_obj["face"]
logger.info(detector_backend) logger.info(f"testing {img_path} with {detector_backend}")
logger.info(face_obj["facial_area"]) logger.info(face_obj["facial_area"])
logger.info(face_obj["confidence"]) logger.info(face_obj["confidence"])
@ -99,7 +98,10 @@ for expand_area in expand_areas:
le_x = face_obj["facial_area"]["left_eye"][0] le_x = face_obj["facial_area"]["left_eye"][0]
assert re_x < le_x, "right eye must be the right eye of the person" assert re_x < le_x, "right eye must be the right eye of the person"
assert isinstance(face_obj["confidence"], float) type_conf = type(face_obj["confidence"])
assert isinstance(
face_obj["confidence"], float
), f"confidence type must be float but it is {type_conf}"
assert face_obj["confidence"] <= 1 assert face_obj["confidence"] <= 1
plt.imshow(face) plt.imshow(face)