Add find_closest_embedding function in verification.py to find the index of the closest embedding to the current embedding from a list of embeddings

This commit is contained in:
YoussefAboelwafa 2024-10-20 16:59:01 +03:00
parent e9eb4829fc
commit 7280287a75
2 changed files with 47 additions and 0 deletions

View File

@ -162,6 +162,19 @@ def verify(
anti_spoofing=anti_spoofing,
)
def find_closest_embedding(
current_embedding: List[float],
embeddings_list: List[List[float]],
distance_metric: str = "cosine",
threshold: Optional[float] = 0.5,
) -> Optional[int]:
return verification.find_closest_embedding(
current_embedding=current_embedding,
embeddings_list=embeddings_list,
distance_metric=distance_metric,
threshold=threshold,
)
def analyze(
img_path: Union[str, np.ndarray],

View File

@ -211,6 +211,40 @@ def verify(
return resp_obj
def find_closest_embedding(
current_embedding: List[float],
embeddings_list: List[List[float]],
distance_metric: str = "cosine",
threshold: Optional[float] = 0.5,
) -> Optional[int]:
"""
Find the index of the closest embedding to the current embedding from a list of embeddings.
Args:
current_embedding (List[float]): The embedding of the current face.
embeddings_list (List[List[float]]): A list of embeddings to compare against.
distance_metric (str): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine).
threshold (float): The maximum threshold used for verification. If the distance is below this threshold, the embeddings are considered a match.
Returns:
Optional[int]: The index of the closest embedding if the distance is less than the threshold; otherwise, None.
"""
min_distance = float("inf")
min_idx = None
for idx, embedding in enumerate(embeddings_list):
distance = find_distance(current_embedding, embedding, distance_metric)
if distance < min_distance:
min_distance = distance
min_idx = idx
if threshold is None:
threshold = find_threshold("VGG-Face", distance_metric) # Default model and metric
if min_distance <= threshold:
return min_idx
else:
return None
def __extract_faces_and_embeddings(
img_path: Union[str, np.ndarray],