diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 35e4f5b..3cb97a0 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -435,7 +435,16 @@ def find(img_path, db_path distances = [] for index, instance in df.iterrows(): source_representation = instance["representation"] - distance = dst.findCosineDistance(source_representation, target_representation) + + if distance_metric == 'cosine': + distance = dst.findCosineDistance(source_representation, target_representation) + elif distance_metric == 'euclidean': + distance = dst.findEuclideanDistance(source_representation, target_representation) + elif distance_metric == 'euclidean_l2': + distance = dst.findEuclideanDistance(dst.l2_normalize(source_representation), dst.l2_normalize(target_representation)) + else: + raise ValueError("Invalid distance_metric passed - ", distance_metric) + distances.append(distance) df["distance"] = distances