From 6ab7d190b89fac3baadb99f47a79a64aca5c877b Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 3 Aug 2024 16:20:24 +0400 Subject: [PATCH] few improvements --- deepface/DeepFace.py | 4 +++ deepface/commons/image_utils.py | 47 +++++++++++++++++++++------------ deepface/modules/recognition.py | 12 ++++++++- 3 files changed, 45 insertions(+), 18 deletions(-) diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 6813691..b11a3dd 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -270,6 +270,7 @@ def find( silent: bool = False, refresh_database: bool = True, anti_spoofing: bool = False, + recursive: bool = True, ) -> List[pd.DataFrame]: """ Identify individuals in a database @@ -281,6 +282,8 @@ def find( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. + recursive (bool): Walk db_path recursively (default True) + model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). @@ -347,6 +350,7 @@ def find( silent=silent, refresh_database=refresh_database, anti_spoofing=anti_spoofing, + recursive=recursive, ) diff --git a/deepface/commons/image_utils.py b/deepface/commons/image_utils.py index c25e411..e139ebf 100644 --- a/deepface/commons/image_utils.py +++ b/deepface/commons/image_utils.py @@ -13,28 +13,45 @@ import cv2 from PIL import Image -def list_images(path: str) -> List[str]: +def is_image(file_path: str) -> bool: + """ + Check if a file is an image + Args: + file_path (str): path to the file + Returns: + bool: True if the file is an image, False otherwise + """ + _, ext = os.path.splitext(file_path) + ext_lower = ext.lower() + + if ext_lower not in {".jpg", ".jpeg", ".png", ".webp"}: + return False + + with Image.open(file_path) as img: # lazy + return img.format.lower() in ["jpeg", "png"] + + +def list_images(path: str, recursive: bool = True) -> List[str]: """ List images in a given path Args: path (str): path's location + recursive (bool): default True Returns: images (list): list of exact image paths """ images = [] - for r, _, f in os.walk(path): - for file in f: - exact_path = os.path.join(r, file) - - _, ext = os.path.splitext(exact_path) - ext_lower = ext.lower() - - if ext_lower not in {".jpg", ".jpeg", ".png"}: - continue - - with Image.open(exact_path) as img: # lazy - if img.format.lower() in ["jpeg", "png"]: + if recursive: + for r, _, f in os.walk(path): + for file in f: + exact_path = os.path.join(r, file) + if is_image(exact_path): images.append(exact_path) + else: + for file in os.listdir(path): + exact_path = os.path.join(path, file) + if is_image(exact_path): + images.append(exact_path) return images @@ -95,10 +112,6 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]: # image must be a file on the system then - # image name must have english characters - if img.isascii() is False: - raise ValueError(f"Input image must not have non-english characters - {img}") - img_obj_bgr = cv2.imread(img) # img_obj_rgb = cv2.cvtColor(img_obj_bgr, cv2.COLOR_BGR2RGB) return img_obj_bgr, img diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index 4b94ad9..7161f04 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -31,6 +31,7 @@ def find( silent: bool = False, refresh_database: bool = True, anti_spoofing: bool = False, + recursive: bool = True, ) -> List[pd.DataFrame]: """ Identify individuals in a database @@ -43,6 +44,8 @@ def find( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. + recursive (bool): Walk db_path recursively (default True) + model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). @@ -152,7 +155,7 @@ def find( pickled_images = [representation["identity"] for representation in representations] # Get the list of images on storage - storage_images = image_utils.list_images(path=db_path) + storage_images = image_utils.list_images(path=db_path, recursive=recursive) if len(storage_images) == 0 and refresh_database is True: raise ValueError(f"No item found in {db_path}") @@ -374,6 +377,13 @@ def __find_bulk_embeddings( logger.error(f"Exception while extracting faces from {employee}: {str(err)}") img_objs = [] + except KeyboardInterrupt: + needInterrupt = os.getenv("DEEPFACE_KEYBOARD_INTERRUPT", '0').lower() in ('true', '1', 't') + if not needInterrupt: + raise + else: + break + if len(img_objs) == 0: representations.append( {