Merge pull request #1 from haddyadnan/add_angular_distance

feat: Add Angular Distance as a Distance Metric
This commit is contained in:
haddyadnan 2025-04-07 22:53:32 +03:00 committed by GitHub
commit 9523214c1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 70 additions and 28 deletions

View File

@ -150,10 +150,10 @@ Conducting experiments with those models within DeepFace may reveal disparities
Face recognition models are regular [convolutional neural networks](https://sefiks.com/2018/03/23/convolutional-autoencoder-clustering-images-with-neural-networks/) and they are responsible to represent faces as vectors. We expect that a face pair of same person should be [more similar](https://sefiks.com/2020/05/22/fine-tuning-the-threshold-in-face-recognition/) than a face pair of different persons. Face recognition models are regular [convolutional neural networks](https://sefiks.com/2018/03/23/convolutional-autoencoder-clustering-images-with-neural-networks/) and they are responsible to represent faces as vectors. We expect that a face pair of same person should be [more similar](https://sefiks.com/2020/05/22/fine-tuning-the-threshold-in-face-recognition/) than a face pair of different persons.
Similarity could be calculated by different metrics such as [Cosine Similarity](https://sefiks.com/2018/08/13/cosine-similarity-in-machine-learning/), Euclidean Distance or L2 normalized Euclidean. The default configuration uses cosine similarity. According to [experiments](https://github.com/serengil/deepface/tree/master/benchmarks), no distance metric is overperforming than other. Similarity could be calculated by different metrics such as [Cosine Similarity](https://sefiks.com/2018/08/13/cosine-similarity-in-machine-learning/), Angular Distance, Euclidean Distance or L2 normalized Euclidean. The default configuration uses cosine similarity. According to [experiments](https://github.com/serengil/deepface/tree/master/benchmarks), no distance metric is overperforming than other.
```python ```python
metrics = ["cosine", "euclidean", "euclidean_l2"] metrics = ["cosine", "euclidean", "euclidean_l2", 'angular']
result = DeepFace.verify( result = DeepFace.verify(
img1_path = "img1.jpg", img2_path = "img2.jpg", distance_metric = metrics[1] img1_path = "img1.jpg", img2_path = "img2.jpg", distance_metric = metrics[1]

View File

@ -102,7 +102,7 @@ def verify(
'centerface' or 'skip' (default is opencv). 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', 'angular' (default is cosine).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
@ -194,7 +194,7 @@ def analyze(
'centerface' or 'skip' (default is opencv). 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', 'angular' (default is cosine).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).
@ -299,7 +299,7 @@ def find(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', 'angular' (default is cosine).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
@ -479,7 +479,7 @@ def stream(
'centerface' or 'skip' (default is opencv). 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', 'angular' (default is cosine).
enable_face_analysis (bool): Flag to enable face analysis (default is True). enable_face_analysis (bool): Flag to enable face analysis (default is True).

View File

@ -39,7 +39,7 @@ def analyze(
'centerface' or 'skip' (default is opencv). 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', 'angular' (default is cosine).
align (boolean): Perform alignment based on the eye positions (default is True). align (boolean): Perform alignment based on the eye positions (default is True).

View File

@ -48,7 +48,7 @@ def find(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2'. 'euclidean', 'euclidean_l2', 'angular'.
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images. Default is True. Set to False to avoid the exception for low-resolution images.
@ -481,7 +481,7 @@ def find_batched(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2'. 'euclidean', 'euclidean_l2', 'angular'.
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images. Default is True. Set to False to avoid the exception for low-resolution images.

View File

@ -51,7 +51,7 @@ def analysis(
'centerface' or 'skip' (default is opencv). 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', 'angular' (default is cosine).
enable_face_analysis (bool): Flag to enable face analysis (default is True). enable_face_analysis (bool): Flag to enable face analysis (default is True).
@ -223,7 +223,7 @@ def search_identity(
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',
'centerface' or 'skip' (default is opencv). 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', angular, (default is cosine).
Returns: Returns:
result (tuple): result consisting of following objects result (tuple): result consisting of following objects
identified image path (str) identified image path (str)
@ -474,7 +474,7 @@ def perform_facial_recognition(
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s',
'yolov11m', 'centerface' or 'skip' (default is opencv). 'yolov11m', 'centerface' or 'skip' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', angular (default is cosine).
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
Returns: Returns:

View File

@ -51,7 +51,7 @@ def verify(
'centerface' or 'skip' (default is opencv) 'centerface' or 'skip' (default is opencv)
distance_metric (string): Metric for measuring similarity. Options: 'cosine', distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine). 'euclidean', 'euclidean_l2', angular (default is cosine).
enforce_detection (boolean): If no face is detected in an image, raise an exception. enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True). Set to False to avoid the exception for low-resolution images (default is True).
@ -297,6 +297,45 @@ def find_cosine_distance(
) )
return distances return distances
def find_angular_distance(
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
) -> Union[np.float64, np.ndarray]:
"""
Find angular distance between two vectors or batches of vectors.
Args:
source_representation (np.ndarray or list): 1st vector or batch of vectors.
test_representation (np.ndarray or list): 2nd vector or batch of vectors.
Returns:
np.float64 or np.ndarray: angular distance(s).
Returns a np.float64 for single embeddings and np.ndarray for batch embeddings.
"""
# calculate cosine similarity first
# then convert to angular distance
source_representation = np.asarray(source_representation)
test_representation = np.asarray(test_representation)
if source_representation.ndim == 1 and test_representation.ndim == 1:
# single embedding
dot_product = np.dot(source_representation, test_representation)
source_norm = np.linalg.norm(source_representation)
test_norm = np.linalg.norm(test_representation)
similarity = dot_product / (source_norm * test_norm)
distances = np.arccos(similarity) / np.pi
elif source_representation.ndim == 2 and test_representation.ndim == 2:
# list of embeddings (batch)
source_normed = l2_normalize(source_representation, axis=1) # (N, D)
test_normed = l2_normalize(test_representation, axis=1) # (M, D)
similarity = np.dot(test_normed, source_normed.T) # (M, N)
distances = np.arccos(similarity) / np.pi
else:
raise ValueError(
f"Embeddings must be 1D or 2D, but received "
f"source shape: {source_representation.shape}, test shape: {test_representation.shape}"
)
return distances
def find_euclidean_distance( def find_euclidean_distance(
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list] source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
@ -362,7 +401,7 @@ def find_distance(
alpha_embedding (np.ndarray or list): 1st vector or batch of vectors. alpha_embedding (np.ndarray or list): 1st vector or batch of vectors.
beta_embedding (np.ndarray or list): 2nd vector or batch of vectors. beta_embedding (np.ndarray or list): 2nd vector or batch of vectors.
distance_metric (str): The type of distance to compute distance_metric (str): The type of distance to compute
('cosine', 'euclidean', or 'euclidean_l2'). ('cosine', 'euclidean', 'euclidean_l2', or 'angular').
Returns: Returns:
np.float64 or np.ndarray: The calculated distance(s). np.float64 or np.ndarray: The calculated distance(s).
@ -380,6 +419,8 @@ def find_distance(
if distance_metric == "cosine": if distance_metric == "cosine":
distance = find_cosine_distance(alpha_embedding, beta_embedding) distance = find_cosine_distance(alpha_embedding, beta_embedding)
elif distance_metric == "angular":
distance = find_angular_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean": elif distance_metric == "euclidean":
distance = find_euclidean_distance(alpha_embedding, beta_embedding) distance = find_euclidean_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean_l2": elif distance_metric == "euclidean_l2":
@ -399,31 +440,32 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (str): distance metric name. Options are cosine, euclidean distance_metric (str): distance metric name. Options are cosine, euclidean
and euclidean_l2. euclidean_l2 and angular.
Returns: Returns:
threshold (float): threshold value for that model name and distance metric threshold (float): threshold value for that model name and distance metric
pair. Distances less than this threshold will be classified same person. pair. Distances less than this threshold will be classified same person.
""" """
base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75} base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75, "angular": 0.37}
thresholds = { thresholds = {
# "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86}, # 2622d # "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86, "angular": 0.37}, # 2622d
"VGG-Face": { "VGG-Face": {
"cosine": 0.68, "cosine": 0.68,
"euclidean": 1.17, "euclidean": 1.17,
"euclidean_l2": 1.17, "euclidean_l2": 1.17,
"angular": 0.43,
}, # 4096d - tuned with LFW }, # 4096d - tuned with LFW
"Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80}, "Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80, "angular": 0.47},
"Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04}, "Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04, "angular": 0.49},
"ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13}, "ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13, "angular": 0.43},
"Dlib": {"cosine": 0.07, "euclidean": 0.6, "euclidean_l2": 0.4}, "Dlib": {"cosine": 0.07, "euclidean": 0.6, "euclidean_l2": 0.4, "angular": 0.50},
"SFace": {"cosine": 0.593, "euclidean": 10.734, "euclidean_l2": 1.055}, "SFace": {"cosine": 0.593, "euclidean": 10.734, "euclidean_l2": 1.055, "angular": 0.445},
"OpenFace": {"cosine": 0.10, "euclidean": 0.55, "euclidean_l2": 0.55}, "OpenFace": {"cosine": 0.10, "euclidean": 0.55, "euclidean_l2": 0.55, "angular": 0.50},
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64}, "DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64, "angular": 0.49},
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17}, "DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17, "angular": 0.50},
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10}, "GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10, "angular": 0.43},
"Buffalo_L": {"cosine": 0.55, "euclidean": 0.6, "euclidean_l2": 1.1}, "Buffalo_L": {"cosine": 0.55, "euclidean": 0.6, "euclidean_l2": 1.1, "angular": 0.45},
} }
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4) threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)

View File

@ -9,7 +9,7 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"] models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"]
metrics = ["cosine", "euclidean", "euclidean_l2"] metrics = ["cosine", "euclidean", "euclidean_l2", "angular"]
detectors = ["opencv", "mtcnn"] detectors = ["opencv", "mtcnn"]