january 26 improvements

This commit is contained in:
Sefik Ilkin Serengil 2024-01-26 17:52:55 +00:00
parent 88814e6d2b
commit 36665a9e96
7 changed files with 109 additions and 21 deletions

View File

@ -2,7 +2,7 @@
import os import os
import warnings import warnings
import logging import logging
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union, Optional
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -65,34 +65,49 @@ def verify(
Args: Args:
img1_path (str or np.ndarray): Path to the first image. Accepts exact image path img1_path (str or np.ndarray): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), or base64 encoded images.
img2_path (str or np.ndarray): Path to the second image. Accepts exact image path img2_path (str or np.ndarray): Path to the second image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), or base64 encoded images.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv) 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
align (bool): Flag to enable face alignment (default is True). align (bool): Flag to enable face alignment (default is True).
normalization (string): Normalize the input image before feeding it to the model. normalization (string): Normalize the input image before feeding it to the model.
Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base) Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base)
Returns: Returns:
result (dict): A dictionary containing verification results with following keys. result (dict): A dictionary containing verification results with following keys.
- 'verified' (bool): Indicates whether the images represent the same person (True) - 'verified' (bool): Indicates whether the images represent the same person (True)
or different persons (False). or different persons (False).
- 'distance' (float): The distance measure between the face vectors. - 'distance' (float): The distance measure between the face vectors.
A lower distance indicates higher similarity. A lower distance indicates higher similarity.
- 'max_threshold_to_verify' (float): The maximum threshold used for verification. - 'max_threshold_to_verify' (float): The maximum threshold used for verification.
If the distance is below this threshold, the images are considered a match. If the distance is below this threshold, the images are considered a match.
- 'model' (str): The chosen face recognition model. - 'model' (str): The chosen face recognition model.
- 'similarity_metric' (str): The chosen similarity metric for measuring distances. - 'similarity_metric' (str): The chosen similarity metric for measuring distances.
- 'facial_areas' (dict): Rectangular regions of interest for faces in both images. - 'facial_areas' (dict): Rectangular regions of interest for faces in both images.
- 'img1': {'x': int, 'y': int, 'w': int, 'h': int} - 'img1': {'x': int, 'y': int, 'w': int, 'h': int}
Region of interest for the first image. Region of interest for the first image.
- 'img2': {'x': int, 'y': int, 'w': int, 'h': int} - 'img2': {'x': int, 'y': int, 'w': int, 'h': int}
Region of interest for the second image. Region of interest for the second image.
- 'time' (float): Time taken for the verification process in seconds. - 'time' (float): Time taken for the verification process in seconds.
""" """
@ -122,37 +137,51 @@ def analyze(
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
or a base64 encoded image. If the source image contains multiple faces, the result will or a base64 encoded image. If the source image contains multiple faces, the result will
include information for each detected face. include information for each detected face.
actions (tuple): Attributes to analyze. The default is ('age', 'gender', 'emotion', 'race'). actions (tuple): Attributes to analyze. The default is ('age', 'gender', 'emotion', 'race').
You can exclude some of these attributes from the analysis if needed. You can exclude some of these attributes from the analysis if needed.
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).
silent (boolean): Suppress or allow some log messages for a quieter analysis process silent (boolean): Suppress or allow some log messages for a quieter analysis process
(default is False). (default is False).
Returns: Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
the analysis results for a detected face. Each dictionary in the list contains the the analysis results for a detected face. Each dictionary in the list contains the
following keys: following keys:
- 'region' (dict): Represents the rectangular region of the detected face in the image. - 'region' (dict): Represents the rectangular region of the detected face in the image.
- 'x': x-coordinate of the top-left corner of the face. - 'x': x-coordinate of the top-left corner of the face.
- 'y': y-coordinate of the top-left corner of the face. - 'y': y-coordinate of the top-left corner of the face.
- 'w': Width of the detected face region. - 'w': Width of the detected face region.
- 'h': Height of the detected face region. - 'h': Height of the detected face region.
- 'age' (float): Estimated age of the detected face. - 'age' (float): Estimated age of the detected face.
- 'face_confidence' (float): Confidence score for the detected face. - 'face_confidence' (float): Confidence score for the detected face.
Indicates the reliability of the face detection. Indicates the reliability of the face detection.
- 'dominant_gender' (str): The dominant gender in the detected face. - 'dominant_gender' (str): The dominant gender in the detected face.
Either "Man" or "Woman." Either "Man" or "Woman".
- 'gender' (dict): Confidence scores for each gender category. - 'gender' (dict): Confidence scores for each gender category.
- 'Man': Confidence score for the male gender. - 'Man': Confidence score for the male gender.
- 'Woman': Confidence score for the female gender. - 'Woman': Confidence score for the female gender.
- 'dominant_emotion' (str): The dominant emotion in the detected face. - 'dominant_emotion' (str): The dominant emotion in the detected face.
Possible values include "sad," "angry," "surprise," "fear," "happy," Possible values include "sad," "angry," "surprise," "fear," "happy,"
"disgust," and "neutral." "disgust," and "neutral"
- 'emotion' (dict): Confidence scores for each emotion category. - 'emotion' (dict): Confidence scores for each emotion category.
- 'sad': Confidence score for sadness. - 'sad': Confidence score for sadness.
- 'angry': Confidence score for anger. - 'angry': Confidence score for anger.
@ -161,9 +190,11 @@ def analyze(
- 'happy': Confidence score for happiness. - 'happy': Confidence score for happiness.
- 'disgust': Confidence score for disgust. - 'disgust': Confidence score for disgust.
- 'neutral': Confidence score for neutrality. - 'neutral': Confidence score for neutrality.
- 'dominant_race' (str): The dominant race in the detected face. - 'dominant_race' (str): The dominant race in the detected face.
Possible values include "indian," "asian," "latino hispanic," Possible values include "indian," "asian," "latino hispanic,"
"black," "middle eastern," and "white." "black," "middle eastern," and "white."
- 'race' (dict): Confidence scores for each race category. - 'race' (dict): Confidence scores for each race category.
- 'indian': Confidence score for Indian ethnicity. - 'indian': Confidence score for Indian ethnicity.
- 'asian': Confidence score for Asian ethnicity. - 'asian': Confidence score for Asian ethnicity.
@ -190,6 +221,7 @@ def find(
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
align: bool = True, align: bool = True,
threshold: Optional[float] = None,
normalization: str = "base", normalization: str = "base",
silent: bool = False, silent: bool = False,
) -> List[pd.DataFrame]: ) -> List[pd.DataFrame]:
@ -199,31 +231,51 @@ def find(
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
or a base64 encoded image. If the source image contains multiple faces, the result will or a base64 encoded image. If the source image contains multiple faces, the result will
include information for each detected face. include information for each detected face.
db_path (string): Path to the folder containing image files. All detected faces db_path (string): Path to the folder containing image files. All detected faces
in the database will be considered in the decision-making process. in the database will be considered in the decision-making process.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2' (default is cosine).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).
threshold (float): Specify a threshold to determine whether a pair represents the same
person or different individuals. This threshold is used for comparing distances.
If left unset, default pre-tuned threshold values will be applied based on the specified
model name and distance metric (default is None).
normalization (string): Normalize the input image before feeding it to the model. normalization (string): Normalize the input image before feeding it to the model.
Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base). Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base).
silent (boolean): Suppress or allow some log messages for a quieter analysis process silent (boolean): Suppress or allow some log messages for a quieter analysis process
(default is False). (default is False).
Returns: Returns:
results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds
to the identity information for an individual detected in the source image. to the identity information for an individual detected in the source image.
The DataFrame columns include: The DataFrame columns include:
- 'identity': Identity label of the detected individual. - 'identity': Identity label of the detected individual.
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the - 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
target face in the database. target face in the database.
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the - 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
detected face in the source image. detected face in the source image.
- '{model_name}_{distance_metric}': Similarity score between the faces based on the
- 'threshold': threshold to determine a pair whether same person or different persons
- 'distance': Similarity score between the faces based on the
specified model and distance metric specified model and distance metric
""" """
return recognition.find( return recognition.find(
@ -234,6 +286,7 @@ def find(
enforce_detection=enforce_detection, enforce_detection=enforce_detection,
detector_backend=detector_backend, detector_backend=detector_backend,
align=align, align=align,
threshold=threshold,
normalization=normalization, normalization=normalization,
silent=silent, silent=silent,
) )
@ -254,27 +307,36 @@ def represent(
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
or a base64 encoded image. If the source image contains multiple faces, the result will or a base64 encoded image. If the source image contains multiple faces, the result will
include information for each detected face. include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face.). OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face.).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images Default is True. Set to False to avoid the exception for low-resolution images
(default is True). (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv). 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).
normalization (string): Normalize the input image before feeding it to the model. normalization (string): Normalize the input image before feeding it to the model.
Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace
(default is base). (default is base).
Returns: Returns:
results (List[Dict[str, Any]]): A list of dictionaries, each containing the results (List[Dict[str, Any]]): A list of dictionaries, each containing the
following fields: following fields:
- embedding (np.array): Multidimensional vector representing facial features. - embedding (np.array): Multidimensional vector representing facial features.
The number of dimensions varies based on the reference model The number of dimensions varies based on the reference model
(e.g., FaceNet returns 128 dimensions, VGG-Face returns 4096 dimensions). (e.g., FaceNet returns 128 dimensions, VGG-Face returns 4096 dimensions).
- facial_area (dict): Detected facial area by face detection in dictionary format. - facial_area (dict): Detected facial area by face detection in dictionary format.
Contains 'x' and 'y' as the left-corner point, and 'w' and 'h' Contains 'x' and 'y' as the left-corner point, and 'w' and 'h'
as the width and height. If `detector_backend` is set to 'skip', it represents as the width and height. If `detector_backend` is set to 'skip', it represents
the full image area and is nonsensical. the full image area and is nonsensical.
- face_confidence (float): Confidence score of face detection. If `detector_backend` is set - face_confidence (float): Confidence score of face detection. If `detector_backend` is set
to 'skip', the confidence will be 0 and is nonsensical. to 'skip', the confidence will be 0 and is nonsensical.
""" """
@ -355,19 +417,28 @@ def extract_faces(
Args: Args:
img_path (str or np.ndarray): Path to the first image. Accepts exact image path img_path (str or np.ndarray): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), or base64 encoded images.
target_size (tuple): final shape of facial image. black pixels will be target_size (tuple): final shape of facial image. black pixels will be
added to resize the image (default is (224, 224)). added to resize the image (default is (224, 224)).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv) 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
align (bool): Flag to enable face alignment (default is True). align (bool): Flag to enable face alignment (default is True).
grayscale (boolean): Flag to convert the image to grayscale before grayscale (boolean): Flag to convert the image to grayscale before
processing (default is False). processing (default is False).
Returns: Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains: results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
- "face" (np.ndarray): The detected face as a NumPy array. - "face" (np.ndarray): The detected face as a NumPy array.
- "facial_area" (List[float]): The detected face's regions represented as a list of floats. - "facial_area" (List[float]): The detected face's regions represented as a list of floats.
- "confidence" (float): The confidence score associated with the detected face. - "confidence" (float): The confidence score associated with the detected face.
""" """

View File

@ -12,7 +12,7 @@ from deepface.commons.logger import Logger
logger = Logger(module="detectors.SsdWrapper") logger = Logger(module="detectors.SsdWrapper")
# pylint: disable=line-too-long # pylint: disable=line-too-long, c-extension-no-member
class SsdClient(Detector): class SsdClient(Detector):

View File

@ -21,6 +21,13 @@ class YuNetClient(Detector):
Returns: Returns:
model (Any) model (Any)
""" """
opencv_version = cv2.__version__.split(".")
if len(opencv_version) > 2 and int(opencv_version[0]) == 4 and int(opencv_version[1]) < 8:
# min requirement: https://github.com/opencv/opencv_zoo/issues/172
raise ValueError(f"YuNet requires opencv-python >= 4.8 but you have {cv2.__version__}")
# pylint: disable=C0301 # pylint: disable=C0301
url = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx" url = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
file_name = "face_detection_yunet_2023mar.onnx" file_name = "face_detection_yunet_2023mar.onnx"
@ -67,7 +74,7 @@ class YuNetClient(Detector):
""" """
# FaceDetector.detect_faces does not support score_threshold parameter. # FaceDetector.detect_faces does not support score_threshold parameter.
# We can set it via environment variable. # We can set it via environment variable.
score_threshold = os.environ.get("yunet_score_threshold", "0.9") score_threshold = float(os.environ.get("yunet_score_threshold", "0.9"))
resp = [] resp = []
detected_face = None detected_face = None
img_region = [0, 0, img.shape[1], img.shape[0]] img_region = [0, 0, img.shape[1], img.shape[0]]

View File

@ -1,7 +1,7 @@
# built-in dependencies # built-in dependencies
import os import os
import pickle import pickle
from typing import List, Union from typing import List, Union, Optional
import time import time
# 3rd party dependencies # 3rd party dependencies
@ -25,6 +25,7 @@ def find(
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
align: bool = True, align: bool = True,
threshold: Optional[float] = None,
normalization: str = "base", normalization: str = "base",
silent: bool = False, silent: bool = False,
) -> List[pd.DataFrame]: ) -> List[pd.DataFrame]:
@ -53,6 +54,11 @@ def find(
align (boolean): Perform alignment based on the eye positions. align (boolean): Perform alignment based on the eye positions.
threshold (float): Specify a threshold to determine whether a pair represents the same
person or different individuals. This threshold is used for comparing distances.
If left unset, default pre-tuned threshold values will be applied based on the specified
model name and distance metric (default is None).
normalization (string): Normalize the input image before feeding it to the model. normalization (string): Normalize the input image before feeding it to the model.
Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace
@ -64,11 +70,16 @@ def find(
The DataFrame columns include: The DataFrame columns include:
- 'identity': Identity label of the detected individual. - 'identity': Identity label of the detected individual.
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the - 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
target face in the database. target face in the database.
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the - 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
detected face in the source image. detected face in the source image.
- '{model_name}_{distance_metric}': Similarity score between the faces based on the
- 'threshold': threshold to determine a pair whether same person or different persons
- 'distance': Similarity score between the faces based on the
specified model and distance metric specified model and distance metric
""" """
@ -248,16 +259,15 @@ def find(
distances.append(distance) distances.append(distance)
# --------------------------- # ---------------------------
target_threshold = threshold or dst.findThreshold(model_name, distance_metric)
result_df[f"{model_name}_{distance_metric}"] = distances result_df["threshold"] = target_threshold
result_df["distance"] = distances
threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"]) result_df = result_df.drop(columns=[f"{model_name}_representation"])
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold] result_df = result_df[result_df["distance"] <= target_threshold]
result_df = result_df.sort_values( result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
by=[f"{model_name}_{distance_metric}"], ascending=True
).reset_index(drop=True)
resp_obj.append(result_df) resp_obj.append(result_df)

View File

@ -8,7 +8,7 @@ with open("requirements.txt", "r", encoding="utf-8") as f:
setuptools.setup( setuptools.setup(
name="deepface", name="deepface",
version="0.0.82", version="0.0.83",
author="Sefik Ilkin Serengil", author="Sefik Ilkin Serengil",
author_email="serengil@gmail.com", author_email="serengil@gmail.com",
description="A Lightweight Face Recognition and Facial Attribute Analysis Framework (Age, Gender, Emotion, Race) for Python", description="A Lightweight Face Recognition and Facial Attribute Analysis Framework (Age, Gender, Emotion, Race) for Python",

View File

@ -21,7 +21,7 @@ def test_find_with_exact_path():
assert identity_df.shape[0] > 0 assert identity_df.shape[0] > 0
# validate reproducability # validate reproducability
assert identity_df["VGG-Face_cosine"].values[0] < threshold assert identity_df["distance"].values[0] < threshold
df = df[df["identity"] != img_path] df = df[df["identity"] != img_path]
logger.debug(df.head()) logger.debug(df.head())
@ -42,7 +42,7 @@ def test_find_with_array_input():
assert identity_df.shape[0] > 0 assert identity_df.shape[0] > 0
# validate reproducability # validate reproducability
assert identity_df["VGG-Face_cosine"].values[0] < threshold assert identity_df["distance"].values[0] < threshold
df = df[df["identity"] != img_path] df = df[df["identity"] != img_path]
logger.debug(df.head()) logger.debug(df.head())
@ -65,7 +65,7 @@ def test_find_with_extracted_faces():
assert identity_df.shape[0] > 0 assert identity_df.shape[0] > 0
# validate reproducability # validate reproducability
assert identity_df["VGG-Face_cosine"].values[0] < threshold assert identity_df["distance"].values[0] < threshold
df = df[df["identity"] != img_path] df = df[df["identity"] != img_path]
logger.debug(df.head()) logger.debug(df.head())

View File

@ -20,7 +20,7 @@ model_names = [
"SFace", "SFace",
] ]
detector_backends = ["opencv", "ssd", "dlib", "mtcnn", "retinaface"] detector_backends = ["opencv", "ssd", "dlib", "mtcnn", "retinaface", "yunet"]
# verification # verification