mirror of
https://github.com/serengil/deepface.git
synced 2025-06-08 20:45:22 +00:00
avoid dimension imcompability error
created pickle may have 2622 dimensional vectors but VGG-Face is not creating 4096 dimensional vectors. If they are mismatch, then raise a meaningful error
This commit is contained in:
parent
05013e550c
commit
0eb1515e11
@ -616,6 +616,15 @@ def find(
|
||||
for index, instance in df.iterrows():
|
||||
source_representation = instance[f"{model_name}_representation"]
|
||||
|
||||
target_dims = len(list(target_representation))
|
||||
source_dims = len(list(source_representation))
|
||||
if target_dims != source_dims:
|
||||
raise ValueError(
|
||||
"Source and target embeddings must have same dimensions but "
|
||||
+ f"{target_dims}:{source_dims}. Model structure may change"
|
||||
+ " after pickle created. Delete the {file_name} and re-run."
|
||||
)
|
||||
|
||||
if distance_metric == "cosine":
|
||||
distance = dst.findCosineDistance(source_representation, target_representation)
|
||||
elif distance_metric == "euclidean":
|
||||
@ -636,6 +645,7 @@ def find(
|
||||
|
||||
threshold = dst.findThreshold(model_name, distance_metric)
|
||||
result_df = result_df.drop(columns=[f"{model_name}_representation"])
|
||||
# pylint: disable=unsubscriptable-object
|
||||
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
|
||||
result_df = result_df.sort_values(
|
||||
by=[f"{model_name}_{distance_metric}"], ascending=True
|
||||
|
Loading…
x
Reference in New Issue
Block a user