mirror of
https://github.com/serengil/deepface.git
synced 2025-06-09 04:55:24 +00:00
Simplify find data initialization
This commit is contained in:
parent
14bbc2f938
commit
411df327bf
@ -100,6 +100,7 @@ def find(
|
|||||||
file_name = f"representations_{model_name}.pkl"
|
file_name = f"representations_{model_name}.pkl"
|
||||||
file_name = file_name.replace("-", "_").lower()
|
file_name = file_name.replace("-", "_").lower()
|
||||||
datastore_path = os.path.join(db_path, file_name)
|
datastore_path = os.path.join(db_path, file_name)
|
||||||
|
representations = []
|
||||||
|
|
||||||
df_cols = [
|
df_cols = [
|
||||||
"identity",
|
"identity",
|
||||||
@ -110,77 +111,48 @@ def find(
|
|||||||
"target_h",
|
"target_h",
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.path.exists(datastore_path):
|
# Ensure the proper pickle file exists
|
||||||
with open(datastore_path, "rb") as f:
|
if not os.path.exists(datastore_path):
|
||||||
representations = pickle.load(f)
|
with open(datastore_path, "wb") as f:
|
||||||
|
pickle.dump([], f)
|
||||||
|
f.close()
|
||||||
|
|
||||||
if len(representations) > 0 and len(representations[0]) != len(df_cols):
|
# Load the representations from the pickle file
|
||||||
raise ValueError(
|
with open(datastore_path, "rb") as f:
|
||||||
f"Seems existing {datastore_path} is out-of-the-date."
|
representations = pickle.load(f)
|
||||||
"Please delete it and re-run."
|
f.close()
|
||||||
)
|
|
||||||
|
|
||||||
alpha_employees = __list_images(path=db_path)
|
# Check if the representations are out-of-date
|
||||||
beta_employees = [representation[0] for representation in representations]
|
if len(representations) > 0:
|
||||||
|
if len(representations[0]) != len(df_cols):
|
||||||
newbies = list(set(alpha_employees) - set(beta_employees))
|
|
||||||
oldies = list(set(beta_employees) - set(alpha_employees))
|
|
||||||
|
|
||||||
if newbies:
|
|
||||||
logger.warn(
|
|
||||||
f"Items {newbies} were added into {db_path}"
|
|
||||||
f" just after data source {datastore_path} created!"
|
|
||||||
)
|
|
||||||
newbies_representations = __find_bulk_embeddings(
|
|
||||||
employees=newbies,
|
|
||||||
model_name=model_name,
|
|
||||||
target_size=target_size,
|
|
||||||
detector_backend=detector_backend,
|
|
||||||
enforce_detection=enforce_detection,
|
|
||||||
align=align,
|
|
||||||
normalization=normalization,
|
|
||||||
silent=silent,
|
|
||||||
)
|
|
||||||
representations = representations + newbies_representations
|
|
||||||
|
|
||||||
if oldies:
|
|
||||||
logger.warn(
|
|
||||||
f"Items {oldies} were dropped from {db_path}"
|
|
||||||
f" just after data source {datastore_path} created!"
|
|
||||||
)
|
|
||||||
representations = [rep for rep in representations if rep[0] not in oldies]
|
|
||||||
|
|
||||||
if newbies or oldies:
|
|
||||||
if len(representations) == 0:
|
|
||||||
raise ValueError(f"There is no image in {db_path} anymore!")
|
|
||||||
|
|
||||||
# save new representations
|
|
||||||
with open(datastore_path, "wb") as f:
|
|
||||||
pickle.dump(representations, f)
|
|
||||||
|
|
||||||
if not silent:
|
|
||||||
logger.info(
|
|
||||||
f"{len(newbies)} new representations are just added"
|
|
||||||
f" whereas {len(oldies)} represented one(s) are just dropped"
|
|
||||||
f" in {os.path.join(db_path,file_name)} file."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not silent:
|
|
||||||
logger.info(f"There are {len(representations)} representations found in {file_name}")
|
|
||||||
|
|
||||||
else: # create representation.pkl from scratch
|
|
||||||
employees = __list_images(path=db_path)
|
|
||||||
|
|
||||||
if len(employees) == 0:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not find any valid image in {db_path} folder!"
|
f"Seems existing {datastore_path} is out-of-the-date."
|
||||||
"Valid images are .jpg, .jpeg or .png files.",
|
"Please delete it and re-run."
|
||||||
)
|
)
|
||||||
|
pickled_images = [representation[0] for representation in representations]
|
||||||
|
else:
|
||||||
|
pickled_images = []
|
||||||
|
|
||||||
# ------------------------
|
# Get the list of images on storage
|
||||||
# find representations for db images
|
storage_images = __list_images(path=db_path)
|
||||||
representations = __find_bulk_embeddings(
|
|
||||||
employees=employees,
|
# Enforce data consistency between the on-disk images and the pickle file
|
||||||
|
must_save_pickle = False
|
||||||
|
new_images = list(set(storage_images) - set(pickled_images)) # images added to storage
|
||||||
|
old_images = list(set(pickled_images) - set(storage_images)) # images removed from storage
|
||||||
|
|
||||||
|
if not silent:
|
||||||
|
logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images")
|
||||||
|
|
||||||
|
# remove old images first
|
||||||
|
if len(old_images)>0:
|
||||||
|
representations = [rep for rep in representations if rep[0] not in old_images]
|
||||||
|
must_save_pickle = True
|
||||||
|
|
||||||
|
# find representations for new images
|
||||||
|
if len(new_images)>0:
|
||||||
|
representations += __find_bulk_embeddings(
|
||||||
|
employees=new_images,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
target_size=target_size,
|
target_size=target_size,
|
||||||
detector_backend=detector_backend,
|
detector_backend=detector_backend,
|
||||||
@ -188,15 +160,22 @@ def find(
|
|||||||
align=align,
|
align=align,
|
||||||
normalization=normalization,
|
normalization=normalization,
|
||||||
silent=silent,
|
silent=silent,
|
||||||
)
|
) # add new images
|
||||||
|
must_save_pickle = True
|
||||||
# -------------------------------
|
|
||||||
|
|
||||||
|
if must_save_pickle:
|
||||||
with open(datastore_path, "wb") as f:
|
with open(datastore_path, "wb") as f:
|
||||||
pickle.dump(representations, f)
|
pickle.dump(representations, f)
|
||||||
|
f.close()
|
||||||
if not silent:
|
if not silent:
|
||||||
logger.info(f"Representations stored in {datastore_path} file.")
|
logger.info(f"There are now {len(representations)} representations in {file_name}")
|
||||||
|
|
||||||
|
# If there are no representations at all, bail out early
|
||||||
|
if len(representations) == 0:
|
||||||
|
toc = time.time()
|
||||||
|
if not silent:
|
||||||
|
logger.info(f"find function duration {toc - tic} seconds")
|
||||||
|
return []
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# now, we got representations for facial database
|
# now, we got representations for facial database
|
||||||
@ -290,7 +269,7 @@ def find(
|
|||||||
toc = time.time()
|
toc = time.time()
|
||||||
|
|
||||||
if not silent:
|
if not silent:
|
||||||
logger.info(f"find function lasts {toc - tic} seconds")
|
logger.info(f"find function duration {toc - tic} seconds")
|
||||||
|
|
||||||
return resp_obj
|
return resp_obj
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user