mirror of
https://github.com/serengil/deepface.git
synced 2025-06-08 12:35:22 +00:00
wrapper for find distance added
This commit is contained in:
parent
2f9f9761d0
commit
d7c2998e1b
@ -250,21 +250,9 @@ def find(
|
||||
+ " after pickle created. Delete the {file_name} and re-run."
|
||||
)
|
||||
|
||||
if distance_metric == "cosine":
|
||||
distance = verification.find_cosine_distance(
|
||||
source_representation, target_representation
|
||||
distance = verification.find_distance(
|
||||
source_representation, target_representation, distance_metric
|
||||
)
|
||||
elif distance_metric == "euclidean":
|
||||
distance = verification.find_euclidean_distance(
|
||||
source_representation, target_representation
|
||||
)
|
||||
elif distance_metric == "euclidean_l2":
|
||||
distance = verification.find_euclidean_distance(
|
||||
verification.l2_normalize(source_representation),
|
||||
verification.l2_normalize(target_representation),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"invalid distance metric passes - {distance_metric}")
|
||||
|
||||
distances.append(distance)
|
||||
|
||||
|
@ -141,16 +141,7 @@ def verify(
|
||||
regions = []
|
||||
for idx, img1_embedding in enumerate(img1_embeddings):
|
||||
for idy, img2_embedding in enumerate(img2_embeddings):
|
||||
if distance_metric == "cosine":
|
||||
distance = find_cosine_distance(img1_embedding, img2_embedding)
|
||||
elif distance_metric == "euclidean":
|
||||
distance = find_euclidean_distance(img1_embedding, img2_embedding)
|
||||
elif distance_metric == "euclidean_l2":
|
||||
distance = find_euclidean_distance(
|
||||
l2_normalize(img1_embedding), l2_normalize(img2_embedding)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||
distance = find_distance(img1_embedding, img2_embedding, distance_metric)
|
||||
distances.append(distance)
|
||||
regions.append((img1_objs[idx]["facial_area"], img2_objs[idy]["facial_area"]))
|
||||
|
||||
@ -234,6 +225,32 @@ def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
|
||||
return x / np.sqrt(np.sum(np.multiply(x, x)))
|
||||
|
||||
|
||||
def find_distance(
|
||||
alpha_embedding: Union[np.ndarray, list],
|
||||
beta_embedding: Union[np.ndarray, list],
|
||||
distance_metric: str,
|
||||
) -> np.float64:
|
||||
"""
|
||||
Wrapper to find distance between vectors according to the given distance metric
|
||||
Args:
|
||||
source_representation (np.ndarray or list): 1st vector
|
||||
test_representation (np.ndarray or list): 2nd vector
|
||||
Returns
|
||||
distance (np.float64): calculated cosine distance
|
||||
"""
|
||||
if distance_metric == "cosine":
|
||||
distance = find_cosine_distance(alpha_embedding, beta_embedding)
|
||||
elif distance_metric == "euclidean":
|
||||
distance = find_euclidean_distance(alpha_embedding, beta_embedding)
|
||||
elif distance_metric == "euclidean_l2":
|
||||
distance = find_euclidean_distance(
|
||||
l2_normalize(alpha_embedding), l2_normalize(beta_embedding)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||
return distance
|
||||
|
||||
|
||||
def find_threshold(model_name: str, distance_metric: str) -> float:
|
||||
"""
|
||||
Retrieve pre-tuned threshold values for a model and distance metric pair
|
||||
|
Loading…
x
Reference in New Issue
Block a user