mirror of
https://github.com/serengil/deepface.git
synced 2025-06-08 12:35:22 +00:00
wrapper for find distance added
This commit is contained in:
parent
2f9f9761d0
commit
d7c2998e1b
@ -250,21 +250,9 @@ def find(
|
|||||||
+ " after pickle created. Delete the {file_name} and re-run."
|
+ " after pickle created. Delete the {file_name} and re-run."
|
||||||
)
|
)
|
||||||
|
|
||||||
if distance_metric == "cosine":
|
distance = verification.find_distance(
|
||||||
distance = verification.find_cosine_distance(
|
source_representation, target_representation, distance_metric
|
||||||
source_representation, target_representation
|
|
||||||
)
|
)
|
||||||
elif distance_metric == "euclidean":
|
|
||||||
distance = verification.find_euclidean_distance(
|
|
||||||
source_representation, target_representation
|
|
||||||
)
|
|
||||||
elif distance_metric == "euclidean_l2":
|
|
||||||
distance = verification.find_euclidean_distance(
|
|
||||||
verification.l2_normalize(source_representation),
|
|
||||||
verification.l2_normalize(target_representation),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"invalid distance metric passes - {distance_metric}")
|
|
||||||
|
|
||||||
distances.append(distance)
|
distances.append(distance)
|
||||||
|
|
||||||
|
@ -141,16 +141,7 @@ def verify(
|
|||||||
regions = []
|
regions = []
|
||||||
for idx, img1_embedding in enumerate(img1_embeddings):
|
for idx, img1_embedding in enumerate(img1_embeddings):
|
||||||
for idy, img2_embedding in enumerate(img2_embeddings):
|
for idy, img2_embedding in enumerate(img2_embeddings):
|
||||||
if distance_metric == "cosine":
|
distance = find_distance(img1_embedding, img2_embedding, distance_metric)
|
||||||
distance = find_cosine_distance(img1_embedding, img2_embedding)
|
|
||||||
elif distance_metric == "euclidean":
|
|
||||||
distance = find_euclidean_distance(img1_embedding, img2_embedding)
|
|
||||||
elif distance_metric == "euclidean_l2":
|
|
||||||
distance = find_euclidean_distance(
|
|
||||||
l2_normalize(img1_embedding), l2_normalize(img2_embedding)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
|
||||||
distances.append(distance)
|
distances.append(distance)
|
||||||
regions.append((img1_objs[idx]["facial_area"], img2_objs[idy]["facial_area"]))
|
regions.append((img1_objs[idx]["facial_area"], img2_objs[idy]["facial_area"]))
|
||||||
|
|
||||||
@ -234,6 +225,32 @@ def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
|
|||||||
return x / np.sqrt(np.sum(np.multiply(x, x)))
|
return x / np.sqrt(np.sum(np.multiply(x, x)))
|
||||||
|
|
||||||
|
|
||||||
|
def find_distance(
|
||||||
|
alpha_embedding: Union[np.ndarray, list],
|
||||||
|
beta_embedding: Union[np.ndarray, list],
|
||||||
|
distance_metric: str,
|
||||||
|
) -> np.float64:
|
||||||
|
"""
|
||||||
|
Wrapper to find distance between vectors according to the given distance metric
|
||||||
|
Args:
|
||||||
|
source_representation (np.ndarray or list): 1st vector
|
||||||
|
test_representation (np.ndarray or list): 2nd vector
|
||||||
|
Returns
|
||||||
|
distance (np.float64): calculated cosine distance
|
||||||
|
"""
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
distance = find_cosine_distance(alpha_embedding, beta_embedding)
|
||||||
|
elif distance_metric == "euclidean":
|
||||||
|
distance = find_euclidean_distance(alpha_embedding, beta_embedding)
|
||||||
|
elif distance_metric == "euclidean_l2":
|
||||||
|
distance = find_euclidean_distance(
|
||||||
|
l2_normalize(alpha_embedding), l2_normalize(beta_embedding)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||||
|
return distance
|
||||||
|
|
||||||
|
|
||||||
def find_threshold(model_name: str, distance_metric: str) -> float:
|
def find_threshold(model_name: str, distance_metric: str) -> float:
|
||||||
"""
|
"""
|
||||||
Retrieve pre-tuned threshold values for a model and distance metric pair
|
Retrieve pre-tuned threshold values for a model and distance metric pair
|
||||||
|
Loading…
x
Reference in New Issue
Block a user