Merge pull request #1032 from AndreaLanfranchi/al20240222-find-simpler

Simplify find data initialization
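Previously, find() took two separate code paths depending on whether the representations pickle already existed: one branch loaded it and patched in "newbies"/"oldies", the other built the file from scratch. This change collapses both into a single flow: ensure the pickle exists (creating an empty one if needed), load it, diff its contents against the images currently on disk, and re-save only when something changed. The heart of the new logic is a pair of set differences; below is a minimal sketch of the idea, with made-up file lists standing in for __list_images() and the loaded pickle:

    # Sketch only: storage_images / pickled_images are illustrative stand-ins.
    storage_images = ["img1.jpg", "img2.jpg", "img4.jpg"]  # files on disk now
    pickled_images = ["img1.jpg", "img2.jpg", "img3.jpg"]  # identities already embedded

    new_images = list(set(storage_images) - set(pickled_images))  # ["img4.jpg"]: embed and append
    old_images = list(set(pickled_images) - set(storage_images))  # ["img3.jpg"]: drop from pickle
    must_save_pickle = len(new_images) > 0 or len(old_images) > 0  # re-save only if changed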
Sefik Ilkin Serengil 2024-02-22 14:01:55 +00:00 committed by GitHub
commit af73d3dfe4
2 changed files with 53 additions and 77 deletions

deepface/modules/recognition.py

@@ -100,6 +100,7 @@ def find(
     file_name = f"representations_{model_name}.pkl"
     file_name = file_name.replace("-", "_").lower()
     datastore_path = os.path.join(db_path, file_name)
+    representations = []

     df_cols = [
         "identity",
@@ -110,77 +111,46 @@ def find(
         "target_h",
     ]

-    if os.path.exists(datastore_path):
-        with open(datastore_path, "rb") as f:
-            representations = pickle.load(f)
+    # Ensure the proper pickle file exists
+    if not os.path.exists(datastore_path):
+        with open(datastore_path, "wb") as f:
+            pickle.dump([], f)

-        if len(representations) > 0 and len(representations[0]) != len(df_cols):
-            raise ValueError(
-                f"Seems existing {datastore_path} is out-of-the-date."
-                "Please delete it and re-run."
-            )
+    # Load the representations from the pickle file
+    with open(datastore_path, "rb") as f:
+        representations = pickle.load(f)

-        alpha_employees = __list_images(path=db_path)
-        beta_employees = [representation[0] for representation in representations]
-
-        newbies = list(set(alpha_employees) - set(beta_employees))
-        oldies = list(set(beta_employees) - set(alpha_employees))
-
-        if newbies:
-            logger.warn(
-                f"Items {newbies} were added into {db_path}"
-                f" just after data source {datastore_path} created!"
-            )
-            newbies_representations = __find_bulk_embeddings(
-                employees=newbies,
-                model_name=model_name,
-                target_size=target_size,
-                detector_backend=detector_backend,
-                enforce_detection=enforce_detection,
-                align=align,
-                normalization=normalization,
-                silent=silent,
-            )
-            representations = representations + newbies_representations
-
-        if oldies:
-            logger.warn(
-                f"Items {oldies} were dropped from {db_path}"
-                f" just after data source {datastore_path} created!"
-            )
-            representations = [rep for rep in representations if rep[0] not in oldies]
-
-        if newbies or oldies:
-            if len(representations) == 0:
-                raise ValueError(f"There is no image in {db_path} anymore!")
-
-            # save new representations
-            with open(datastore_path, "wb") as f:
-                pickle.dump(representations, f)
-
-            if not silent:
-                logger.info(
-                    f"{len(newbies)} new representations are just added"
-                    f" whereas {len(oldies)} represented one(s) are just dropped"
-                    f" in {os.path.join(db_path,file_name)} file."
-                )
-
-        if not silent:
-            logger.info(f"There are {len(representations)} representations found in {file_name}")
-
-    else:  # create representation.pkl from scratch
-        employees = __list_images(path=db_path)
-
-        if len(employees) == 0:
+    # Check if the representations are out-of-date
+    if len(representations) > 0:
+        if len(representations[0]) != len(df_cols):
             raise ValueError(
-                f"Could not find any valid image in {db_path} folder!"
-                "Valid images are .jpg, .jpeg or .png files.",
+                f"Seems existing {datastore_path} is out-of-the-date."
+                "Please delete it and re-run."
             )
+        pickled_images = [representation[0] for representation in representations]
+    else:
+        pickled_images = []

-        # ------------------------
-        # find representations for db images
-        representations = __find_bulk_embeddings(
-            employees=employees,
+    # Get the list of images on storage
+    storage_images = __list_images(path=db_path)
+
+    # Enforce data consistency amongst on disk images and pickle file
+    must_save_pickle = False
+    new_images = list(set(storage_images) - set(pickled_images))  # images added to storage
+    old_images = list(set(pickled_images) - set(storage_images))  # images removed from storage
+
+    if not silent and (len(new_images) > 0 or len(old_images) > 0):
+        logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images")
+
+    # remove old images first
+    if len(old_images)>0:
+        representations = [rep for rep in representations if rep[0] not in old_images]
+        must_save_pickle = True
+
+    # find representations for new images
+    if len(new_images)>0:
+        representations += __find_bulk_embeddings(
+            employees=new_images,
             model_name=model_name,
             target_size=target_size,
             detector_backend=detector_backend,
@@ -188,15 +158,21 @@ def find(
             align=align,
             normalization=normalization,
             silent=silent,
-        )
-
-        # -------------------------------
-
+        )  # add new images
+        must_save_pickle = True
+
+    if must_save_pickle:
         with open(datastore_path, "wb") as f:
             pickle.dump(representations, f)
-
         if not silent:
-            logger.info(f"Representations stored in {datastore_path} file.")
+            logger.info(f"There are now {len(representations)} representations in {file_name}")

+    # Should we have no representations bailout
+    if len(representations) == 0:
+        if not silent:
+            toc = time.time()
+            logger.info(f"find function duration {toc - tic} seconds")
+        return []
+
     # ----------------------------
     # now, we got representations for facial database
@@ -287,10 +263,9 @@ def find(
     # -----------------------------------

-    toc = time.time()
-
     if not silent:
-        logger.info(f"find function lasts {toc - tic} seconds")
+        toc = time.time()
+        logger.info(f"find function duration {toc - tic} seconds")

     return resp_obj
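Net effect for callers: the first DeepFace.find() against a folder builds the pickle, later calls only re-embed files that were added or removed, and an empty list is returned early when no representations exist. A hedged usage sketch (paths are illustrative; with the default VGG-Face model the datastore is representations_vgg_face.pkl inside db_path, per the naming scheme above):

    from deepface import DeepFace

    # First call creates dataset/representations_vgg_face.pkl; subsequent calls
    # reuse it, embedding only images added to (or dropping images removed from) dataset/.
    dfs = DeepFace.find(img_path="dataset/img1.jpg", db_path="dataset", silent=True)
    print(len(dfs))  # one DataFrame per face detected in img1.jpg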

tests/test_find.py

@@ -1,3 +1,4 @@
+import os
 import cv2
 import pandas as pd
 from deepface import DeepFace
@@ -10,7 +11,7 @@ threshold = verification.find_threshold(model_name="VGG-Face", distance_metric="

 def test_find_with_exact_path():
-    img_path = "dataset/img1.jpg"
+    img_path = os.path.join("dataset","img1.jpg")
     dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
     assert len(dfs) > 0
     for df in dfs:
@@ -30,7 +31,7 @@ def test_find_with_exact_path():

 def test_find_with_array_input():
-    img_path = "dataset/img1.jpg"
+    img_path = os.path.join("dataset","img1.jpg")
     img1 = cv2.imread(img_path)
     dfs = DeepFace.find(img1, db_path="dataset", silent=True)
     assert len(dfs) > 0
@@ -52,7 +53,7 @@ def test_find_with_array_input():

 def test_find_with_extracted_faces():
-    img_path = "dataset/img1.jpg"
+    img_path = os.path.join("dataset","img1.jpg")
     face_objs = DeepFace.extract_faces(img_path)
     img = face_objs[0]["face"]
     dfs = DeepFace.find(img, db_path="dataset", detector_backend="skip", silent=True)
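The test changes swap the hard-coded POSIX-style path for os.path.join, so the same assertions hold on any platform; os.path.join picks the separator appropriate to the host OS:

    import os

    # "dataset/img1.jpg" on Linux/macOS, "dataset\img1.jpg" on Windows
    img_path = os.path.join("dataset", "img1.jpg")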