avoid dimension incompatibility error

The created pickle may have 2622-dimensional vectors, but VGG-Face is now creating
4096-dimensional vectors. If they mismatch, then raise a meaningful error.
This commit is contained in:
Sefik Ilkin Serengil 2024-01-08 16:59:20 +00:00
parent 05013e550c
commit 0eb1515e11

View File

@ -616,6 +616,15 @@ def find(
for index, instance in df.iterrows(): for index, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"] source_representation = instance[f"{model_name}_representation"]
target_dims = len(list(target_representation))
source_dims = len(list(source_representation))
if target_dims != source_dims:
raise ValueError(
"Source and target embeddings must have same dimensions but "
+ f"{target_dims}:{source_dims}. Model structure may change"
+ " after pickle created. Delete the {file_name} and re-run."
)
if distance_metric == "cosine": if distance_metric == "cosine":
distance = dst.findCosineDistance(source_representation, target_representation) distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == "euclidean": elif distance_metric == "euclidean":
@ -636,6 +645,7 @@ def find(
threshold = dst.findThreshold(model_name, distance_metric) threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"]) result_df = result_df.drop(columns=[f"{model_name}_representation"])
# pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold] result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
result_df = result_df.sort_values( result_df = result_df.sort_values(
by=[f"{model_name}_{distance_metric}"], ascending=True by=[f"{model_name}_{distance_metric}"], ascending=True