wrapper for find distance added

This commit is contained in:
Sefik Ilkin Serengil 2024-03-08 13:55:02 +00:00
parent 2f9f9761d0
commit d7c2998e1b
2 changed files with 30 additions and 25 deletions

View File

@ -250,21 +250,9 @@ def find(
+ " after pickle created. Delete the {file_name} and re-run."
)
if distance_metric == "cosine":
distance = verification.find_cosine_distance(
source_representation, target_representation
distance = verification.find_distance(
source_representation, target_representation, distance_metric
)
elif distance_metric == "euclidean":
distance = verification.find_euclidean_distance(
source_representation, target_representation
)
elif distance_metric == "euclidean_l2":
distance = verification.find_euclidean_distance(
verification.l2_normalize(source_representation),
verification.l2_normalize(target_representation),
)
else:
raise ValueError(f"invalid distance metric passes - {distance_metric}")
distances.append(distance)

View File

@ -141,16 +141,7 @@ def verify(
regions = []
for idx, img1_embedding in enumerate(img1_embeddings):
for idy, img2_embedding in enumerate(img2_embeddings):
if distance_metric == "cosine":
distance = find_cosine_distance(img1_embedding, img2_embedding)
elif distance_metric == "euclidean":
distance = find_euclidean_distance(img1_embedding, img2_embedding)
elif distance_metric == "euclidean_l2":
distance = find_euclidean_distance(
l2_normalize(img1_embedding), l2_normalize(img2_embedding)
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
distance = find_distance(img1_embedding, img2_embedding, distance_metric)
distances.append(distance)
regions.append((img1_objs[idx]["facial_area"], img2_objs[idy]["facial_area"]))
@ -234,6 +225,32 @@ def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
return x / np.sqrt(np.sum(np.multiply(x, x)))
def find_distance(
alpha_embedding: Union[np.ndarray, list],
beta_embedding: Union[np.ndarray, list],
distance_metric: str,
) -> np.float64:
"""
Wrapper to find distance between vectors according to the given distance metric
Args:
source_representation (np.ndarray or list): 1st vector
test_representation (np.ndarray or list): 2nd vector
Returns
distance (np.float64): calculated cosine distance
"""
if distance_metric == "cosine":
distance = find_cosine_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean":
distance = find_euclidean_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean_l2":
distance = find_euclidean_distance(
l2_normalize(alpha_embedding), l2_normalize(beta_embedding)
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
return distance
def find_threshold(model_name: str, distance_metric: str) -> float:
"""
Retrieve pre-tuned threshold values for a model and distance metric pair