distance in find function implemented

This commit is contained in:
Şefik Serangil 2020-05-25 17:23:15 +03:00
parent c67ae7fe6a
commit 4440d82d44

View File

@ -435,7 +435,16 @@ def find(img_path, db_path
distances = []
for index, instance in df.iterrows():
source_representation = instance["representation"]
distance = dst.findCosineDistance(source_representation, target_representation)
if distance_metric == 'cosine':
distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == 'euclidean':
distance = dst.findEuclideanDistance(source_representation, target_representation)
elif distance_metric == 'euclidean_l2':
distance = dst.findEuclideanDistance(dst.l2_normalize(source_representation), dst.l2_normalize(target_representation))
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
distances.append(distance)
df["distance"] = distances