mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
january 26 improvements
This commit is contained in:
parent
88814e6d2b
commit
36665a9e96
@ -2,7 +2,7 @@
|
||||
import os
|
||||
import warnings
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
|
||||
# 3rd party dependencies
|
||||
import numpy as np
|
||||
@ -65,34 +65,49 @@ def verify(
|
||||
Args:
|
||||
img1_path (str or np.ndarray): Path to the first image. Accepts exact image path
|
||||
as a string, numpy array (BGR), or base64 encoded images.
|
||||
|
||||
img2_path (str or np.ndarray): Path to the second image. Accepts exact image path
|
||||
as a string, numpy array (BGR), or base64 encoded images.
|
||||
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv)
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
align (bool): Flag to enable face alignment (default is True).
|
||||
|
||||
normalization (string): Normalize the input image before feeding it to the model.
|
||||
Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base)
|
||||
|
||||
Returns:
|
||||
result (dict): A dictionary containing verification results with following keys.
|
||||
|
||||
- 'verified' (bool): Indicates whether the images represent the same person (True)
|
||||
or different persons (False).
|
||||
|
||||
- 'distance' (float): The distance measure between the face vectors.
|
||||
A lower distance indicates higher similarity.
|
||||
|
||||
- 'max_threshold_to_verify' (float): The maximum threshold used for verification.
|
||||
If the distance is below this threshold, the images are considered a match.
|
||||
|
||||
- 'model' (str): The chosen face recognition model.
|
||||
|
||||
- 'similarity_metric' (str): The chosen similarity metric for measuring distances.
|
||||
|
||||
- 'facial_areas' (dict): Rectangular regions of interest for faces in both images.
|
||||
- 'img1': {'x': int, 'y': int, 'w': int, 'h': int}
|
||||
Region of interest for the first image.
|
||||
- 'img2': {'x': int, 'y': int, 'w': int, 'h': int}
|
||||
Region of interest for the second image.
|
||||
|
||||
- 'time' (float): Time taken for the verification process in seconds.
|
||||
"""
|
||||
|
||||
@ -122,37 +137,51 @@ def analyze(
|
||||
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
|
||||
or a base64 encoded image. If the source image contains multiple faces, the result will
|
||||
include information for each detected face.
|
||||
|
||||
actions (tuple): Attributes to analyze. The default is ('age', 'gender', 'emotion', 'race').
|
||||
You can exclude some of these attributes from the analysis if needed.
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
silent (boolean): Suppress or allow some log messages for a quieter analysis process
|
||||
(default is False).
|
||||
|
||||
Returns:
|
||||
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
|
||||
the analysis results for a detected face. Each dictionary in the list contains the
|
||||
following keys:
|
||||
|
||||
- 'region' (dict): Represents the rectangular region of the detected face in the image.
|
||||
- 'x': x-coordinate of the top-left corner of the face.
|
||||
- 'y': y-coordinate of the top-left corner of the face.
|
||||
- 'w': Width of the detected face region.
|
||||
- 'h': Height of the detected face region.
|
||||
|
||||
- 'age' (float): Estimated age of the detected face.
|
||||
|
||||
- 'face_confidence' (float): Confidence score for the detected face.
|
||||
Indicates the reliability of the face detection.
|
||||
|
||||
- 'dominant_gender' (str): The dominant gender in the detected face.
|
||||
Either "Man" or "Woman."
|
||||
Either "Man" or "Woman".
|
||||
|
||||
- 'gender' (dict): Confidence scores for each gender category.
|
||||
- 'Man': Confidence score for the male gender.
|
||||
- 'Woman': Confidence score for the female gender.
|
||||
|
||||
- 'dominant_emotion' (str): The dominant emotion in the detected face.
|
||||
Possible values include "sad," "angry," "surprise," "fear," "happy,"
|
||||
"disgust," and "neutral."
|
||||
"disgust," and "neutral"
|
||||
|
||||
- 'emotion' (dict): Confidence scores for each emotion category.
|
||||
- 'sad': Confidence score for sadness.
|
||||
- 'angry': Confidence score for anger.
|
||||
@ -161,9 +190,11 @@ def analyze(
|
||||
- 'happy': Confidence score for happiness.
|
||||
- 'disgust': Confidence score for disgust.
|
||||
- 'neutral': Confidence score for neutrality.
|
||||
|
||||
- 'dominant_race' (str): The dominant race in the detected face.
|
||||
Possible values include "indian," "asian," "latino hispanic,"
|
||||
"black," "middle eastern," and "white."
|
||||
|
||||
- 'race' (dict): Confidence scores for each race category.
|
||||
- 'indian': Confidence score for Indian ethnicity.
|
||||
- 'asian': Confidence score for Asian ethnicity.
|
||||
@ -190,6 +221,7 @@ def find(
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
threshold: Optional[float] = None,
|
||||
normalization: str = "base",
|
||||
silent: bool = False,
|
||||
) -> List[pd.DataFrame]:
|
||||
@ -199,31 +231,51 @@ def find(
|
||||
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
|
||||
or a base64 encoded image. If the source image contains multiple faces, the result will
|
||||
include information for each detected face.
|
||||
|
||||
db_path (string): Path to the folder containing image files. All detected faces
|
||||
in the database will be considered in the decision-making process.
|
||||
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
threshold (float): Specify a threshold to determine whether a pair represents the same
|
||||
person or different individuals. This threshold is used for comparing distances.
|
||||
If left unset, default pre-tuned threshold values will be applied based on the specified
|
||||
model name and distance metric (default is None).
|
||||
|
||||
normalization (string): Normalize the input image before feeding it to the model.
|
||||
Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base).
|
||||
|
||||
silent (boolean): Suppress or allow some log messages for a quieter analysis process
|
||||
(default is False).
|
||||
|
||||
Returns:
|
||||
results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds
|
||||
to the identity information for an individual detected in the source image.
|
||||
The DataFrame columns include:
|
||||
|
||||
- 'identity': Identity label of the detected individual.
|
||||
|
||||
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
|
||||
target face in the database.
|
||||
|
||||
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
|
||||
detected face in the source image.
|
||||
- '{model_name}_{distance_metric}': Similarity score between the faces based on the
|
||||
|
||||
- 'threshold': threshold to determine a pair whether same person or different persons
|
||||
|
||||
- 'distance': Similarity score between the faces based on the
|
||||
specified model and distance metric
|
||||
"""
|
||||
return recognition.find(
|
||||
@ -234,6 +286,7 @@ def find(
|
||||
enforce_detection=enforce_detection,
|
||||
detector_backend=detector_backend,
|
||||
align=align,
|
||||
threshold=threshold,
|
||||
normalization=normalization,
|
||||
silent=silent,
|
||||
)
|
||||
@ -254,27 +307,36 @@ def represent(
|
||||
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
|
||||
or a base64 encoded image. If the source image contains multiple faces, the result will
|
||||
include information for each detected face.
|
||||
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face.).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Default is True. Set to False to avoid the exception for low-resolution images
|
||||
(default is True).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
normalization (string): Normalize the input image before feeding it to the model.
|
||||
Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace
|
||||
(default is base).
|
||||
|
||||
Returns:
|
||||
results (List[Dict[str, Any]]): A list of dictionaries, each containing the
|
||||
following fields:
|
||||
|
||||
- embedding (np.array): Multidimensional vector representing facial features.
|
||||
The number of dimensions varies based on the reference model
|
||||
(e.g., FaceNet returns 128 dimensions, VGG-Face returns 4096 dimensions).
|
||||
|
||||
- facial_area (dict): Detected facial area by face detection in dictionary format.
|
||||
Contains 'x' and 'y' as the left-corner point, and 'w' and 'h'
|
||||
as the width and height. If `detector_backend` is set to 'skip', it represents
|
||||
the full image area and is nonsensical.
|
||||
|
||||
- face_confidence (float): Confidence score of face detection. If `detector_backend` is set
|
||||
to 'skip', the confidence will be 0 and is nonsensical.
|
||||
"""
|
||||
@ -355,19 +417,28 @@ def extract_faces(
|
||||
Args:
|
||||
img_path (str or np.ndarray): Path to the first image. Accepts exact image path
|
||||
as a string, numpy array (BGR), or base64 encoded images.
|
||||
|
||||
target_size (tuple): final shape of facial image. black pixels will be
|
||||
added to resize the image (default is (224, 224)).
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv)
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
|
||||
align (bool): Flag to enable face alignment (default is True).
|
||||
|
||||
grayscale (boolean): Flag to convert the image to grayscale before
|
||||
processing (default is False).
|
||||
|
||||
Returns:
|
||||
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
|
||||
|
||||
- "face" (np.ndarray): The detected face as a NumPy array.
|
||||
|
||||
- "facial_area" (List[float]): The detected face's regions represented as a list of floats.
|
||||
|
||||
- "confidence" (float): The confidence score associated with the detected face.
|
||||
"""
|
||||
|
||||
|
@ -12,7 +12,7 @@ from deepface.commons.logger import Logger
|
||||
|
||||
logger = Logger(module="detectors.SsdWrapper")
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
# pylint: disable=line-too-long, c-extension-no-member
|
||||
|
||||
|
||||
class SsdClient(Detector):
|
||||
|
@ -21,6 +21,13 @@ class YuNetClient(Detector):
|
||||
Returns:
|
||||
model (Any)
|
||||
"""
|
||||
|
||||
opencv_version = cv2.__version__.split(".")
|
||||
|
||||
if len(opencv_version) > 2 and int(opencv_version[0]) == 4 and int(opencv_version[1]) < 8:
|
||||
# min requirement: https://github.com/opencv/opencv_zoo/issues/172
|
||||
raise ValueError(f"YuNet requires opencv-python >= 4.8 but you have {cv2.__version__}")
|
||||
|
||||
# pylint: disable=C0301
|
||||
url = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
|
||||
file_name = "face_detection_yunet_2023mar.onnx"
|
||||
@ -67,7 +74,7 @@ class YuNetClient(Detector):
|
||||
"""
|
||||
# FaceDetector.detect_faces does not support score_threshold parameter.
|
||||
# We can set it via environment variable.
|
||||
score_threshold = os.environ.get("yunet_score_threshold", "0.9")
|
||||
score_threshold = float(os.environ.get("yunet_score_threshold", "0.9"))
|
||||
resp = []
|
||||
detected_face = None
|
||||
img_region = [0, 0, img.shape[1], img.shape[0]]
|
||||
|
@ -1,7 +1,7 @@
|
||||
# built-in dependencies
|
||||
import os
|
||||
import pickle
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
|
||||
# 3rd party dependencies
|
||||
@ -25,6 +25,7 @@ def find(
|
||||
enforce_detection: bool = True,
|
||||
detector_backend: str = "opencv",
|
||||
align: bool = True,
|
||||
threshold: Optional[float] = None,
|
||||
normalization: str = "base",
|
||||
silent: bool = False,
|
||||
) -> List[pd.DataFrame]:
|
||||
@ -53,6 +54,11 @@ def find(
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions.
|
||||
|
||||
threshold (float): Specify a threshold to determine whether a pair represents the same
|
||||
person or different individuals. This threshold is used for comparing distances.
|
||||
If left unset, default pre-tuned threshold values will be applied based on the specified
|
||||
model name and distance metric (default is None).
|
||||
|
||||
normalization (string): Normalize the input image before feeding it to the model.
|
||||
Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace
|
||||
|
||||
@ -64,11 +70,16 @@ def find(
|
||||
The DataFrame columns include:
|
||||
|
||||
- 'identity': Identity label of the detected individual.
|
||||
|
||||
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
|
||||
target face in the database.
|
||||
|
||||
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
|
||||
detected face in the source image.
|
||||
- '{model_name}_{distance_metric}': Similarity score between the faces based on the
|
||||
|
||||
- 'threshold': threshold to determine a pair whether same person or different persons
|
||||
|
||||
- 'distance': Similarity score between the faces based on the
|
||||
specified model and distance metric
|
||||
"""
|
||||
|
||||
@ -248,16 +259,15 @@ def find(
|
||||
distances.append(distance)
|
||||
|
||||
# ---------------------------
|
||||
target_threshold = threshold or dst.findThreshold(model_name, distance_metric)
|
||||
|
||||
result_df[f"{model_name}_{distance_metric}"] = distances
|
||||
result_df["threshold"] = target_threshold
|
||||
result_df["distance"] = distances
|
||||
|
||||
threshold = dst.findThreshold(model_name, distance_metric)
|
||||
result_df = result_df.drop(columns=[f"{model_name}_representation"])
|
||||
# pylint: disable=unsubscriptable-object
|
||||
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
|
||||
result_df = result_df.sort_values(
|
||||
by=[f"{model_name}_{distance_metric}"], ascending=True
|
||||
).reset_index(drop=True)
|
||||
result_df = result_df[result_df["distance"] <= target_threshold]
|
||||
result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
|
||||
|
||||
resp_obj.append(result_df)
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -8,7 +8,7 @@ with open("requirements.txt", "r", encoding="utf-8") as f:
|
||||
|
||||
setuptools.setup(
|
||||
name="deepface",
|
||||
version="0.0.82",
|
||||
version="0.0.83",
|
||||
author="Sefik Ilkin Serengil",
|
||||
author_email="serengil@gmail.com",
|
||||
description="A Lightweight Face Recognition and Facial Attribute Analysis Framework (Age, Gender, Emotion, Race) for Python",
|
||||
|
@ -21,7 +21,7 @@ def test_find_with_exact_path():
|
||||
assert identity_df.shape[0] > 0
|
||||
|
||||
# validate reproducability
|
||||
assert identity_df["VGG-Face_cosine"].values[0] < threshold
|
||||
assert identity_df["distance"].values[0] < threshold
|
||||
|
||||
df = df[df["identity"] != img_path]
|
||||
logger.debug(df.head())
|
||||
@ -42,7 +42,7 @@ def test_find_with_array_input():
|
||||
assert identity_df.shape[0] > 0
|
||||
|
||||
# validate reproducability
|
||||
assert identity_df["VGG-Face_cosine"].values[0] < threshold
|
||||
assert identity_df["distance"].values[0] < threshold
|
||||
|
||||
df = df[df["identity"] != img_path]
|
||||
logger.debug(df.head())
|
||||
@ -65,7 +65,7 @@ def test_find_with_extracted_faces():
|
||||
assert identity_df.shape[0] > 0
|
||||
|
||||
# validate reproducability
|
||||
assert identity_df["VGG-Face_cosine"].values[0] < threshold
|
||||
assert identity_df["distance"].values[0] < threshold
|
||||
|
||||
df = df[df["identity"] != img_path]
|
||||
logger.debug(df.head())
|
||||
|
@ -20,7 +20,7 @@ model_names = [
|
||||
"SFace",
|
||||
]
|
||||
|
||||
detector_backends = ["opencv", "ssd", "dlib", "mtcnn", "retinaface"]
|
||||
detector_backends = ["opencv", "ssd", "dlib", "mtcnn", "retinaface", "yunet"]
|
||||
|
||||
|
||||
# verification
|
||||
|
Loading…
x
Reference in New Issue
Block a user