enhancements for the find function

- detect replaced files already in the data store
- store column names in the pickle
Sefik Ilkin Serengil 2024-03-08 12:01:31 +00:00
parent ad53a9bb28
commit 07a2d5bf62
2 changed files with 88 additions and 39 deletions
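
The "store column names in the pickle" bullet changes the on-disk format: rather than pickling bare tuples next to a separate column list, each image is now stored as a dict whose keys double as the DataFrame column names, so the schema travels with the data. A minimal sketch of the idea (the path, digest and embedding values below are illustrative, not from this commit):

import pickle

import pandas as pd

# each entry carries its own column names
representations = [
    {
        "identity": "db/alice/1.jpg",   # hypothetical path
        "hash": "3f786850e387550f...",  # truncated sha1 digest, illustrative
        "embedding": [0.12, -0.07],     # real embeddings are much longer
        "target_x": 0,
        "target_y": 0,
        "target_w": 112,
        "target_h": 112,
    }
]

with open("ds_vggface_opencv_v2.pkl", "wb") as f:
    pickle.dump(representations, f)

# pandas infers the columns directly from the dict keys
df = pd.DataFrame(representations)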


@@ -1,3 +1,6 @@
# built-in dependencies
import hashlib
# 3rd party dependencies
import tensorflow as tf
@@ -14,3 +17,16 @@ def get_tf_major_version() -> int:
major_version (int)
"""
return int(tf.__version__.split(".", maxsplit=1)[0])
def find_hash_of_file(file_path: str) -> str:
"""
Find the hash of an image file
Args:
file_path (str): exact image path
Returns:
hash (str): digest of the file with the sha1 algorithm
"""
with open(file_path, "rb") as f:
digest = hashlib.sha1(f.read()).hexdigest()
return digest
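
A quick usage sketch of the new helper (paths are hypothetical): hashing a file once at indexing time and again at search time is what lets find() notice that a file's content changed even though its name did not.

from deepface.commons import package_utils

stored_hash = package_utils.find_hash_of_file("db/img1.jpg")  # kept in the pickle
# ... later, db/img1.jpg is overwritten with a different photo ...
current_hash = package_utils.find_hash_of_file("db/img1.jpg")
if stored_hash != current_hash:
    print("db/img1.jpg was replaced; its embedding must be refreshed")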


@@ -1,7 +1,7 @@
# built-in dependencies
import os
import pickle
from typing import List, Union, Optional
from typing import List, Union, Optional, Dict, Any
import time
# 3rd party dependencies
@@ -11,6 +11,7 @@ from tqdm import tqdm
# project dependencies
from deepface.commons.logger import Logger
from deepface.commons import package_utils
from deepface.modules import representation, detection, modeling, verification
from deepface.models.FacialRecognition import FacialRecognition
@@ -97,14 +98,16 @@ def find(
# ---------------------------------------
file_name = f"representations_{model_name}.pkl"
file_name = file_name.replace("-", "_").lower()
file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
file_name = file_name.replace("-", "").lower()
datastore_path = os.path.join(db_path, file_name)
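
For example, model_name="VGG-Face" with detector_backend="opencv" now resolves to ds_vggface_opencv_v2.pkl; the v2 suffix presumably keeps the new dict-based pickles apart from old tuple-based ones:

model_name, detector_backend = "VGG-Face", "opencv"
file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
file_name = file_name.replace("-", "").lower()
print(file_name)  # ds_vggface_opencv_v2.pkl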
representations = []
# required columns for representations
df_cols = [
"identity",
f"{model_name}_representation",
"hash",
"embedding",
"target_x",
"target_y",
"target_w",
@@ -120,14 +123,18 @@ def find(
with open(datastore_path, "rb") as f:
representations = pickle.load(f)
# check that each item of the representations list has the required keys
for i, current_representation in enumerate(representations):
missing_keys = list(set(df_cols) - set(current_representation.keys()))
if len(missing_keys) > 0:
raise ValueError(
f"{i}-th item does not have some required keys - {missing_keys}."
f"Consider to delete {datastore_path}"
)
# Check if the representations are out-of-date
if len(representations) > 0:
if len(representations[0]) != len(df_cols):
raise ValueError(
f"Seems existing {datastore_path} is out-of-the-date."
"Please delete it and re-run."
)
pickled_images = [representation[0] for representation in representations]
pickled_images = [representation["identity"] for representation in representations]
else:
pickled_images = []
@@ -142,9 +149,25 @@ def find(
if not silent and (len(new_images) > 0 or len(old_images) > 0):
logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images")
# detect replaced images
replaced_images = []
for current_representation in representations:
identity = current_representation["identity"]
if identity in old_images:
continue
alpha_hash = current_representation["hash"]
beta_hash = package_utils.find_hash_of_file(identity)
if alpha_hash != beta_hash:
logger.warn(f"Even though {identity} represented before, it's replaced later.")
replaced_images.append(identity)
# append replaced images to both old and new image lists; they will be dropped and re-added
new_images = new_images + replaced_images
old_images = old_images + replaced_images
# remove old images first
if len(old_images) > 0:
representations = [rep for rep in representations if rep[0] not in old_images]
representations = [rep for rep in representations if rep["identity"] not in old_images]
must_save_pickle = True
# find representations for new images
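
The replace-detection above boils down to three sets; a compact sketch of the same bookkeeping (list_images is a hypothetical directory scan, the rest mirrors the diff):

stored = {rep["identity"]: rep["hash"] for rep in representations}
on_disk = set(list_images(db_path))  # hypothetical helper

new_images = on_disk - stored.keys()
old_images = stored.keys() - on_disk
replaced_images = {
    identity
    for identity in on_disk & stored.keys()
    if package_utils.find_hash_of_file(identity) != stored[identity]
}

# replaced files are treated as removed and re-added
new_images |= replaced_images
old_images |= replaced_images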
@@ -176,10 +199,7 @@ def find(
# ----------------------------
# now we have representations for the facial database
df = pd.DataFrame(
representations,
columns=df_cols,
)
df = pd.DataFrame(representations)
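
Dropping the columns= argument works because pd.DataFrame derives the columns from the dict keys, which is exactly why the column names were moved into the pickle. A toy demonstration:

import pandas as pd

rows = [
    {"identity": "a.jpg", "embedding": None},
    {"identity": "b.jpg", "embedding": [0.1, 0.2]},
]
print(pd.DataFrame(rows).columns.tolist())  # ['identity', 'embedding']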
# img path might have more than one face
source_objs = detection.extract_faces(
@@ -216,7 +236,7 @@ def find(
distances = []
for _, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"]
source_representation = instance["embedding"]
if source_representation is None:
distances.append(float("inf")) # no representation for this image
continue
@@ -254,7 +274,7 @@ def find(
result_df["threshold"] = target_threshold
result_df["distance"] = distances
result_df = result_df.drop(columns=[f"{model_name}_representation"])
result_df = result_df.drop(columns=["embedding"])
# pylint: disable=unsubscriptable-object
result_df = result_df[result_df["distance"] <= target_threshold]
result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
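
The same filter-then-sort pattern on a toy frame (distances and threshold are made up):

import pandas as pd

result_df = pd.DataFrame({"identity": ["a.jpg", "b.jpg"], "distance": [0.9, 0.3]})
result_df = result_df[result_df["distance"] <= 0.4]
result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
print(result_df)  # only b.jpg survives the 0.4 threshold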
@@ -297,7 +317,7 @@ def __find_bulk_embeddings(
expand_percentage: int = 0,
normalization: str = "base",
silent: bool = False,
):
) -> List[Dict[str, Any]]:
"""
Find embeddings of a list of images
@@ -323,8 +343,8 @@ def __find_bulk_embeddings(
silent (bool): enable or disable informative logging
Returns:
representations (list): pivot list of embeddings with
image name and detected face area's coordinates
representations (list): list of dicts with
image name, hash, embedding and detected face area's coordinates
"""
representations = []
for employee in tqdm(
@@ -332,6 +352,8 @@ def __find_bulk_embeddings(
desc="Finding representations",
disable=silent,
):
file_hash = package_utils.find_hash_of_file(employee)
try:
img_objs = detection.extract_faces(
img_path=employee,
@@ -342,15 +364,23 @@ def __find_bulk_embeddings(
align=align,
expand_percentage=expand_percentage,
)
except ValueError as err:
logger.error(
f"Exception while extracting faces from {employee}: {str(err)}"
)
logger.error(f"Exception while extracting faces from {employee}: {str(err)}")
img_objs = []
if len(img_objs) == 0:
logger.warn(f"No face detected in {employee}. It will be skipped in detection.")
representations.append((employee, None, 0, 0, 0, 0))
representations.append(
{
"identity": employee,
"hash": file_hash,
"embedding": None,
"target_x": 0,
"target_y": 0,
"target_w": 0,
"target_h": 0,
}
)
else:
for img_obj in img_objs:
img_content = img_obj["face"]
@@ -365,13 +395,16 @@ def __find_bulk_embeddings(
)
img_representation = embedding_obj[0]["embedding"]
representations.append((
employee,
img_representation,
img_region["x"],
img_region["y"],
img_region["w"],
img_region["h"]
))
representations.append(
{
"identity": employee,
"hash": file_hash,
"embedding": img_representation,
"target_x": img_region["x"],
"target_y": img_region["y"],
"target_w": img_region["w"],
"target_h": img_region["h"],
}
)
return representations
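
End to end, nothing changes for callers. A typical invocation (paths are placeholders) now transparently picks up replaced files on the next run instead of requiring the pickle to be deleted by hand:

from deepface import DeepFace

dfs = DeepFace.find(
    img_path="query.jpg",
    db_path="my_db",
    model_name="VGG-Face",
    detector_backend="opencv",
)
# overwrite my_db/someone/photo.jpg and call find() again:
# only that file is re-embedded, the rest of the pickle is reused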