From 7280287a752815b37943dd21ea94f094065c931e Mon Sep 17 00:00:00 2001 From: YoussefAboelwafa Date: Sun, 20 Oct 2024 16:59:01 +0300 Subject: [PATCH] Add find_closest_embedding function in verification.py to find the index of the closest embedding to the current embedding from a list of embeddings --- deepface/DeepFace.py | 13 ++++++++++++ deepface/modules/verification.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index af5245f..f04dfd4 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -162,6 +162,19 @@ def verify( anti_spoofing=anti_spoofing, ) +def find_closest_embedding( + current_embedding: List[float], + embeddings_list: List[List[float]], + distance_metric: str = "cosine", + threshold: Optional[float] = 0.5, +) -> Optional[int]: + + return verification.find_closest_embedding( + current_embedding=current_embedding, + embeddings_list=embeddings_list, + distance_metric=distance_metric, + threshold=threshold, + ) def analyze( img_path: Union[str, np.ndarray], diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py index 540b63b..7b1f4d4 100644 --- a/deepface/modules/verification.py +++ b/deepface/modules/verification.py @@ -211,6 +211,40 @@ def verify( return resp_obj +def find_closest_embedding( + current_embedding: List[float], + embeddings_list: List[List[float]], + distance_metric: str = "cosine", + threshold: Optional[float] = 0.5, +) -> Optional[int]: + """ + Find the index of the closest embedding to the current embedding from a list of embeddings. + + Args: + current_embedding (List[float]): The embedding of the current face. + embeddings_list (List[List[float]]): A list of embeddings to compare against. + distance_metric (str): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). + threshold (float): The maximum threshold used for verification. If the distance is below this threshold, the embeddings are considered a match. + + Returns: + Optional[int]: The index of the closest embedding if the distance is less than the threshold; otherwise, None. + """ + min_distance = float("inf") + min_idx = None + + for idx, embedding in enumerate(embeddings_list): + distance = find_distance(current_embedding, embedding, distance_metric) + if distance < min_distance: + min_distance = distance + min_idx = idx + + if threshold is None: + threshold = find_threshold("VGG-Face", distance_metric) # Default model and metric + + if min_distance <= threshold: + return min_idx + else: + return None def __extract_faces_and_embeddings( img_path: Union[str, np.ndarray],