From 07a2d5bf627eecc1feaf307b7a7ef18852971511 Mon Sep 17 00:00:00 2001
From: Sefik Ilkin Serengil
Date: Fri, 8 Mar 2024 12:01:31 +0000
Subject: [PATCH] enhancement for find function - detect replaced files already in data store - store column names in the pickle

---
 deepface/commons/package_utils.py |  16 +++++
 deepface/modules/recognition.py   | 111 +++++++++++++++++++-----------
 2 files changed, 88 insertions(+), 39 deletions(-)

diff --git a/deepface/commons/package_utils.py b/deepface/commons/package_utils.py
index 9326b94..2226e07 100644
--- a/deepface/commons/package_utils.py
+++ b/deepface/commons/package_utils.py
@@ -1,3 +1,6 @@
+# built-in dependencies
+import hashlib
+
 # 3rd party dependencies
 import tensorflow as tf
 
@@ -14,3 +17,16 @@ def get_tf_major_version() -> int:
         major_version (int)
     """
     return int(tf.__version__.split(".", maxsplit=1)[0])
+
+
+def find_hash_of_file(file_path: str) -> str:
+    """
+    Find hash of image file
+    Args:
+        file_path (str): exact image path
+    Returns:
+        hash (str): digest with sha1 algorithm
+    """
+    with open(file_path, "rb") as f:
+        digest = hashlib.sha1(f.read()).hexdigest()
+    return digest
diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py
index fa2b3b7..b3a8a11 100644
--- a/deepface/modules/recognition.py
+++ b/deepface/modules/recognition.py
@@ -1,7 +1,7 @@
 # built-in dependencies
 import os
 import pickle
-from typing import List, Union, Optional
+from typing import List, Union, Optional, Dict, Any
 import time
 
 # 3rd party dependencies
@@ -11,6 +11,7 @@ from tqdm import tqdm
 
 # project dependencies
 from deepface.commons.logger import Logger
+from deepface.commons import package_utils
 from deepface.modules import representation, detection, modeling, verification
 from deepface.models.FacialRecognition import FacialRecognition
 
@@ -97,14 +98,16 @@ def find(
     # ---------------------------------------
 
-    file_name = f"representations_{model_name}.pkl"
-    file_name = file_name.replace("-", "_").lower()
+    file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
+    file_name = file_name.replace("-", "").lower()
     datastore_path = os.path.join(db_path, file_name)
 
     representations = []
 
+    # required columns for representations
     df_cols = [
         "identity",
-        f"{model_name}_representation",
+        "hash",
+        "embedding",
         "target_x",
         "target_y",
         "target_w",
@@ -120,14 +123,18 @@ def find(
         with open(datastore_path, "rb") as f:
             representations = pickle.load(f)
 
+        # check each item of representations list has required keys
+        for i, current_representation in enumerate(representations):
+            missing_keys = list(set(df_cols) - set(current_representation.keys()))
+            if len(missing_keys) > 0:
+                raise ValueError(
+                    f"{i}-th item does not have some required keys - {missing_keys}."
+                    f" Consider deleting {datastore_path}."
+                )
+
     # Check if the representations are out-of-date
     if len(representations) > 0:
-        if len(representations[0]) != len(df_cols):
-            raise ValueError(
-                f"Seems existing {datastore_path} is out-of-the-date."
-                "Please delete it and re-run."
-            )
-        pickled_images = [representation[0] for representation in representations]
+        pickled_images = [representation["identity"] for representation in representations]
     else:
         pickled_images = []
 
@@ -136,19 +143,35 @@ def find(
 
     # Enforce data consistency amongst on disk images and pickle file
    must_save_pickle = False
-    new_images = list(set(storage_images) - set(pickled_images)) # images added to storage
-    old_images = list(set(pickled_images) - set(storage_images)) # images removed from storage
+    new_images = list(set(storage_images) - set(pickled_images))  # images added to storage
+    old_images = list(set(pickled_images) - set(storage_images))  # images removed from storage
 
     if not silent and (len(new_images) > 0 or len(old_images) > 0):
         logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images")
 
+    # detect replaced images
+    replaced_images = []
+    for current_representation in representations:
+        identity = current_representation["identity"]
+        if identity in old_images:
+            continue
+        alpha_hash = current_representation["hash"]
+        beta_hash = package_utils.find_hash_of_file(identity)
+        if alpha_hash != beta_hash:
+            logger.warn(f"Even though {identity} was represented before, it has been replaced.")
+            replaced_images.append(identity)
+
+    # append replaced images into both old and new images. these will be dropped and re-added.
+    new_images = new_images + replaced_images
+    old_images = old_images + replaced_images
+
     # remove old images first
-    if len(old_images)>0:
-        representations = [rep for rep in representations if rep[0] not in old_images]
+    if len(old_images) > 0:
+        representations = [rep for rep in representations if rep["identity"] not in old_images]
         must_save_pickle = True
 
     # find representations for new images
-    if len(new_images)>0:
+    if len(new_images) > 0:
         representations += __find_bulk_embeddings(
             employees=new_images,
             model_name=model_name,
@@ -158,7 +181,7 @@ def find(
             align=align,
             normalization=normalization,
             silent=silent,
-        ) # add new images
+        )  # add new images
         must_save_pickle = True
 
     if must_save_pickle:
@@ -176,10 +199,7 @@ def find(
 
     # ----------------------------
     # now, we got representations for facial database
-    df = pd.DataFrame(
-        representations,
-        columns=df_cols,
-    )
+    df = pd.DataFrame(representations)
 
     # img path might have more than once face
     source_objs = detection.extract_faces(
@@ -216,9 +236,9 @@ def find(
 
         distances = []
         for _, instance in df.iterrows():
-            source_representation = instance[f"{model_name}_representation"]
+            source_representation = instance["embedding"]
             if source_representation is None:
-                distances.append(float("inf")) # no representation for this image
+                distances.append(float("inf"))  # no representation for this image
                 continue
 
             target_dims = len(list(target_representation))
@@ -254,7 +274,7 @@ def find(
         result_df["threshold"] = target_threshold
         result_df["distance"] = distances
 
-        result_df = result_df.drop(columns=[f"{model_name}_representation"])
+        result_df = result_df.drop(columns=["embedding"])
         # pylint: disable=unsubscriptable-object
         result_df = result_df[result_df["distance"] <= target_threshold]
         result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
@@ -297,7 +317,7 @@ def __find_bulk_embeddings(
     expand_percentage: int = 0,
     normalization: str = "base",
     silent: bool = False,
-):
+) -> List[Dict[str, Any]]:
     """
     Find embeddings of a list of images
 
@@ -323,8 +343,8 @@ def __find_bulk_embeddings(
         silent (bool): enable or disable informative logging
 
     Returns:
-        representations (list): pivot list of embeddings with
-            image name and detected face area's coordinates
+        representations (list): pivot list of dict with
+            image name, hash, embedding and detected face area's coordinates
     """
     representations = []
     for employee in tqdm(
         employees,
@@ -332,6 +352,8 @@ def __find_bulk_embeddings(
         desc="Finding representations",
         disable=silent,
     ):
+        file_hash = package_utils.find_hash_of_file(employee)
+
         try:
             img_objs = detection.extract_faces(
                 img_path=employee,
                 target_size=target_size,
                 detector_backend=detector_backend,
                 grayscale=False,
                 enforce_detection=enforce_detection,
@@ -342,15 +364,23 @@ def __find_bulk_embeddings(
                 align=align,
                 expand_percentage=expand_percentage,
             )
+
         except ValueError as err:
-            logger.error(
-                f"Exception while extracting faces from {employee}: {str(err)}"
-            )
+            logger.error(f"Exception while extracting faces from {employee}: {str(err)}")
             img_objs = []
 
         if len(img_objs) == 0:
-            logger.warn(f"No face detected in {employee}. It will be skipped in detection.")
-            representations.append((employee, None, 0, 0, 0, 0))
+            representations.append(
+                {
+                    "identity": employee,
+                    "hash": file_hash,
+                    "embedding": None,
+                    "target_x": 0,
+                    "target_y": 0,
+                    "target_w": 0,
+                    "target_h": 0,
+                }
+            )
         else:
             for img_obj in img_objs:
                 img_content = img_obj["face"]
@@ -365,13 +395,16 @@ def __find_bulk_embeddings(
                 )
                 img_representation = embedding_obj[0]["embedding"]
 
-                representations.append((
-                    employee,
-                    img_representation,
-                    img_region["x"],
-                    img_region["y"],
-                    img_region["w"],
-                    img_region["h"]
-                ))
+                representations.append(
+                    {
+                        "identity": employee,
+                        "hash": file_hash,
+                        "embedding": img_representation,
+                        "target_x": img_region["x"],
+                        "target_y": img_region["y"],
+                        "target_w": img_region["w"],
+                        "target_h": img_region["h"],
+                    }
+                )
 
     return representations
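
A minimal, self-contained sketch of the synchronization idea this patch implements, for reviewers who want to try the hash-based replacement check outside the library. The helper names (hash_file, plan_sync) and the "my_db" folder are illustrative only; the patch itself does this inside find() through package_utils.find_hash_of_file and __find_bulk_embeddings as shown in the diff above.

# sketch: detect added, removed and replaced images against pickled entries
# (illustrative only; mirrors the logic of the patch, not the library API)
import hashlib
import os
from typing import Any, Dict, List, Tuple


def hash_file(file_path: str) -> str:
    # same idea as package_utils.find_hash_of_file: sha1 digest of the raw bytes
    with open(file_path, "rb") as f:
        return hashlib.sha1(f.read()).hexdigest()


def plan_sync(
    entries: List[Dict[str, Any]], storage_images: List[str]
) -> Tuple[List[str], List[str]]:
    """Return (entries to drop, images to re-embed) for a pickled entry list."""
    pickled = [e["identity"] for e in entries]

    new_images = list(set(storage_images) - set(pickled))  # added to storage
    old_images = list(set(pickled) - set(storage_images))  # removed from storage

    # an entry whose file still exists but whose bytes changed was replaced:
    # drop the stale entry and schedule the same path for re-embedding
    replaced = [
        e["identity"]
        for e in entries
        if e["identity"] not in old_images and e["hash"] != hash_file(e["identity"])
    ]
    return old_images + replaced, new_images + replaced


if __name__ == "__main__":
    db_path = "my_db"  # hypothetical facial database folder
    storage = [
        os.path.join(root, name)
        for root, _, names in os.walk(db_path)
        for name in names
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
    entries: List[Dict[str, Any]] = []  # would normally be loaded from the pickle
    to_drop, to_embed = plan_sync(entries, storage)
    print(f"drop {len(to_drop)} stale entries, embed {len(to_embed)} images")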