mirror of
https://github.com/serengil/deepface.git
synced 2025-06-02 09:30:06 +00:00
feat: Add Angular Distance as a Distance Metric
This commit is contained in:
parent
96a7b98f33
commit
9e334057df
@ -102,7 +102,7 @@ def verify(
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', 'angular' (default is cosine).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
@ -194,7 +194,7 @@ def analyze(
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', 'angular' (default is cosine).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
@ -299,7 +299,7 @@ def find(
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', 'angular' (default is cosine).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
@ -479,7 +479,7 @@ def stream(
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', 'angular' (default is cosine).
|
||||
|
||||
enable_face_analysis (bool): Flag to enable face analysis (default is True).
|
||||
|
||||
|
@ -39,7 +39,7 @@ def analyze(
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', 'angular' (default is cosine).
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions (default is True).
|
||||
|
||||
|
@ -48,7 +48,7 @@ def find(
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2'.
|
||||
'euclidean', 'euclidean_l2', 'angular'.
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
@ -481,7 +481,7 @@ def find_batched(
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2'.
|
||||
'euclidean', 'euclidean_l2', 'angular'.
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
|
@ -51,7 +51,7 @@ def analysis(
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', 'angular' (default is cosine).
|
||||
|
||||
enable_face_analysis (bool): Flag to enable face analysis (default is True).
|
||||
|
||||
@ -223,7 +223,7 @@ def search_identity(
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
|
||||
'centerface' or 'skip' (default is opencv).
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', angular, (default is cosine).
|
||||
Returns:
|
||||
result (tuple): result consisting of following objects
|
||||
identified image path (str)
|
||||
@ -474,7 +474,7 @@ def perform_facial_recognition(
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
|
||||
'yolov11m', 'centerface' or 'skip' (default is opencv).
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', angular (default is cosine).
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
Returns:
|
||||
|
@ -51,7 +51,7 @@ def verify(
|
||||
'centerface' or 'skip' (default is opencv)
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2' (default is cosine).
|
||||
'euclidean', 'euclidean_l2', angular (default is cosine).
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Set to False to avoid the exception for low-resolution images (default is True).
|
||||
@ -297,6 +297,45 @@ def find_cosine_distance(
|
||||
)
|
||||
return distances
|
||||
|
||||
def find_angular_distance(
|
||||
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
|
||||
) -> Union[np.float64, np.ndarray]:
|
||||
"""
|
||||
Find angular distance between two vectors or batches of vectors.
|
||||
|
||||
Args:
|
||||
source_representation (np.ndarray or list): 1st vector or batch of vectors.
|
||||
test_representation (np.ndarray or list): 2nd vector or batch of vectors.
|
||||
|
||||
Returns:
|
||||
np.float64 or np.ndarray: angular distance(s).
|
||||
Returns a np.float64 for single embeddings and np.ndarray for batch embeddings.
|
||||
"""
|
||||
|
||||
# calculate cosine similarity first
|
||||
# then convert to angular distance
|
||||
source_representation = np.asarray(source_representation)
|
||||
test_representation = np.asarray(test_representation)
|
||||
|
||||
if source_representation.ndim == 1 and test_representation.ndim == 1:
|
||||
# single embedding
|
||||
dot_product = np.dot(source_representation, test_representation)
|
||||
source_norm = np.linalg.norm(source_representation)
|
||||
test_norm = np.linalg.norm(test_representation)
|
||||
similarity = dot_product / (source_norm * test_norm)
|
||||
distances = np.arccos(similarity) / np.pi
|
||||
elif source_representation.ndim == 2 and test_representation.ndim == 2:
|
||||
# list of embeddings (batch)
|
||||
source_normed = l2_normalize(source_representation, axis=1) # (N, D)
|
||||
test_normed = l2_normalize(test_representation, axis=1) # (M, D)
|
||||
similarity = np.dot(test_normed, source_normed.T) # (M, N)
|
||||
distances = np.arccos(similarity) / np.pi
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Embeddings must be 1D or 2D, but received "
|
||||
f"source shape: {source_representation.shape}, test shape: {test_representation.shape}"
|
||||
)
|
||||
return distances
|
||||
|
||||
def find_euclidean_distance(
|
||||
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
|
||||
@ -362,7 +401,7 @@ def find_distance(
|
||||
alpha_embedding (np.ndarray or list): 1st vector or batch of vectors.
|
||||
beta_embedding (np.ndarray or list): 2nd vector or batch of vectors.
|
||||
distance_metric (str): The type of distance to compute
|
||||
('cosine', 'euclidean', or 'euclidean_l2').
|
||||
('cosine', 'euclidean', 'euclidean_l2', or 'angular').
|
||||
|
||||
Returns:
|
||||
np.float64 or np.ndarray: The calculated distance(s).
|
||||
@ -380,6 +419,8 @@ def find_distance(
|
||||
|
||||
if distance_metric == "cosine":
|
||||
distance = find_cosine_distance(alpha_embedding, beta_embedding)
|
||||
elif distance_metric == "angular":
|
||||
distance = find_angular_distance(alpha_embedding, beta_embedding)
|
||||
elif distance_metric == "euclidean":
|
||||
distance = find_euclidean_distance(alpha_embedding, beta_embedding)
|
||||
elif distance_metric == "euclidean_l2":
|
||||
@ -399,31 +440,32 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
distance_metric (str): distance metric name. Options are cosine, euclidean
|
||||
and euclidean_l2.
|
||||
euclidean_l2 and angular.
|
||||
Returns:
|
||||
threshold (float): threshold value for that model name and distance metric
|
||||
pair. Distances less than this threshold will be classified same person.
|
||||
"""
|
||||
|
||||
base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75}
|
||||
base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75, "angular": 0.37}
|
||||
|
||||
thresholds = {
|
||||
# "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86}, # 2622d
|
||||
# "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86, "angular": 0.37}, # 2622d
|
||||
"VGG-Face": {
|
||||
"cosine": 0.68,
|
||||
"euclidean": 1.17,
|
||||
"euclidean_l2": 1.17,
|
||||
"angular": 0.43,
|
||||
}, # 4096d - tuned with LFW
|
||||
"Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80},
|
||||
"Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04},
|
||||
"ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13},
|
||||
"Dlib": {"cosine": 0.07, "euclidean": 0.6, "euclidean_l2": 0.4},
|
||||
"SFace": {"cosine": 0.593, "euclidean": 10.734, "euclidean_l2": 1.055},
|
||||
"OpenFace": {"cosine": 0.10, "euclidean": 0.55, "euclidean_l2": 0.55},
|
||||
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
|
||||
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
|
||||
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
|
||||
"Buffalo_L": {"cosine": 0.55, "euclidean": 0.6, "euclidean_l2": 1.1},
|
||||
"Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80, "angular": 0.47},
|
||||
"Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04, "angular": 0.49},
|
||||
"ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13, "angular": 0.43},
|
||||
"Dlib": {"cosine": 0.07, "euclidean": 0.6, "euclidean_l2": 0.4, "angular": 0.50},
|
||||
"SFace": {"cosine": 0.593, "euclidean": 10.734, "euclidean_l2": 1.055, "angular": 0.445},
|
||||
"OpenFace": {"cosine": 0.10, "euclidean": 0.55, "euclidean_l2": 0.55, "angular": 0.50},
|
||||
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64, "angular": 0.49},
|
||||
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17, "angular": 0.50},
|
||||
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10, "angular": 0.43},
|
||||
"Buffalo_L": {"cosine": 0.55, "euclidean": 0.6, "euclidean_l2": 1.1, "angular": 0.45},
|
||||
}
|
||||
|
||||
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
|
||||
|
@ -9,7 +9,7 @@ from deepface.commons.logger import Logger
|
||||
logger = Logger()
|
||||
|
||||
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"]
|
||||
metrics = ["cosine", "euclidean", "euclidean_l2"]
|
||||
metrics = ["cosine", "euclidean", "euclidean_l2", "angular"]
|
||||
detectors = ["opencv", "mtcnn"]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user