diff --git a/README.md b/README.md
index 830c71f..59da5ee 100644
--- a/README.md
+++ b/README.md
@@ -405,23 +405,27 @@ If you do like this work, then you can support it financially on [Patreon](https
+<!--
+[commented-out badge markup; HTML stripped in extraction: "Featured on Hacker News",
+"DeepFace - A Lightweight Deep Face Recognition Library for Python | Product Hunt"]
+-->

 ## Citation

diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py
index 3abe6db..30d8910 100644
--- a/deepface/DeepFace.py
+++ b/deepface/DeepFace.py
@@ -2,7 +2,7 @@
 import os
 import warnings
 import logging
-from typing import Any, Dict, IO, List, Union, Optional
+from typing import Any, Dict, IO, List, Union, Optional, Sequence

 # this has to be set before importing tensorflow
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
@@ -174,7 +174,7 @@ def analyze(
     expand_percentage: int = 0,
     silent: bool = False,
     anti_spoofing: bool = False,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Analyze facial attributes such as age, gender, emotion, and race in the provided image.
     Args:
@@ -206,7 +206,10 @@
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
+        (List[List[Dict[str, Any]]]): A list of per-image analysis results if a batched image
+            was received, as explained below.
+
+        (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
            the analysis results for a detected face. Each dictionary in the list contains the
            following keys:
@@ -373,7 +376,7 @@ def find(


 def represent(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -382,15 +385,18 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.

     Args:
-        img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+            The exact path to the image, a numpy array
             in BGR format, a file object that supports at least `.read` and is opened in binary
             mode, or a base64 encoded image. If the source image contains multiple faces,
-            the result will include information for each detected face.
+            the result will include information for each detected face. If a sequence is
+            provided, each element should be a string, numpy array, or file object representing
+            an image, and the images will be processed as a batch.

         model_name (str): Model for face recognition.
             Options: VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
             ArcFace, SFace and GhostFaceNet
@@ -417,8 +423,9 @@
         max_faces (int): Set a limit on the number of faces to be processed (default is None).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries,
+            or a list of lists of dictionaries if a batch input was passed.
+            Each dictionary contains the following fields:

         - embedding (List[float]): Multidimensional vector representing facial features.
            The number of dimensions varies based on the reference model

diff --git a/deepface/models/Demography.py b/deepface/models/Demography.py
index ad93920..0d8a2de 100644
--- a/deepface/models/Demography.py
+++ b/deepface/models/Demography.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, List
 from abc import ABC, abstractmethod
 import numpy as np
 from deepface.commons import package_utils
@@ -18,5 +18,51 @@ class Demography(ABC):
     model_name: str

     @abstractmethod
-    def predict(self, img: np.ndarray) -> Union[np.ndarray, np.float64]:
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.float64]:
         pass
+
+    def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
+        """
+        Predict for a single image or a batch of images.
+        Uses the legacy direct-call method when a single image is received,
+        and switches to batch prediction for batched input.
+
+        Args:
+            img_batch:
+                Batch of images as np.ndarray (n, x, y, c)
+                with n >= 1, x = image width, y = image height, c = channel,
+                or a single image as np.ndarray (1, x, y, c)
+                with x = image width, y = image height and c = channel.
+                The channel dimension is 1 if the input is grayscale (emotion model).
+        """
+        if not self.model_name:  # Check if called from derived class
+            raise NotImplementedError("no model selected")
+        assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
+
+        if img_batch.shape[0] == 1:  # Single image
+            # Predict with legacy method.
+            return self.model(img_batch, training=False).numpy()[0, :]
+
+        # Batch of images: predict with batch prediction
+        return self.model.predict_on_batch(img_batch)
+
+    def _preprocess_batch_or_single_input(
+        self, img: Union[np.ndarray, List[np.ndarray]]
+    ) -> np.ndarray:
+        """
+        Preprocess a single image or a batch of images; return a 4-D numpy array.
+
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            Four-dimensional numpy array (n, 224, 224, 3)
+        """
+        image_batch = np.array(img)
+
+        # Check input dimension
+        if len(image_batch.shape) == 3:
+            # Single image - add batch dimension
+            image_batch = np.expand_dims(image_batch, axis=0)
+        return image_batch

diff --git a/deepface/models/FacialRecognition.py b/deepface/models/FacialRecognition.py
index a6ee7b5..410a033 100644
--- a/deepface/models/FacialRecognition.py
+++ b/deepface/models/FacialRecognition.py
@@ -18,7 +18,7 @@ class FacialRecognition(ABC):
     input_shape: Tuple[int, int]
     output_shape: int

-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         if not isinstance(self.model, Model):
             raise ValueError(
                 "You must overwrite forward method if it is not a keras model,"
@@ -26,4 +26,9 @@
         )
         # model.predict causes memory issue when it is called in a for loop
         # embedding = model.predict(img, verbose=0)[0].tolist()
-        return self.model(img, training=False).numpy()[0].tolist()
+        if img.ndim == 3:  # ensure a batch dimension before calling the model
+            img = np.expand_dims(img, axis=0)
+        embeddings = self.model(img, training=False).numpy()
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()

diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py
index 67ab3ae..f5a56c6 100644
--- a/deepface/models/demography/Age.py
+++ b/deepface/models/demography/Age.py
@@ -1,3 +1,7 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np

@@ -9,7 +13,6 @@ from deepface.commons.logger import Logger
 logger = Logger()

-# ----------------------------------------
 # dependency configurations

 tf_version = package_utils.get_tf_major_version()
@@ -21,12 +24,11 @@ else:
     from tensorflow.keras.models import Model, Sequential
     from tensorflow.keras.layers import Convolution2D, Flatten, Activation

-# ----------------------------------------

 WEIGHTS_URL = (
     "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
 )

+
 # pylint: disable=too-few-public-methods
 class ApparentAgeClient(Demography):
     """
@@ -37,11 +39,28 @@
         self.model = load_model()
         self.model_name = "Age"

-    def predict(self, img: np.ndarray) -> np.float64:
-        # model.predict causes memory issue when it is called in a for loop
-        # age_predictions = self.model.predict(img, verbose=0)[0, :]
-        age_predictions = self.model(img, training=False).numpy()[0, :]
-        return find_apparent_age(age_predictions)
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
+        """
+        Predict apparent age(s) for single or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.float64 (apparent age) if a single image was given,
+            np.ndarray (n,) of apparent ages if a batch was given.
+        """
+        # Preprocess the input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)
+
+        # Prediction from 3-channel images
+        age_predictions = self._predict_internal(imgs)
+
+        # Calculate apparent ages
+        if len(age_predictions.shape) == 1:  # Single prediction list
+            return find_apparent_age(age_predictions)
+
+        return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])


 def load_model(
@@ -65,7 +84,7 @@

     # --------------------------

-    age_model = Model(inputs=model.input, outputs=base_model_output)
+    age_model = Model(inputs=model.inputs, outputs=base_model_output)

     # --------------------------

@@ -83,10 +102,14 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
     """
     Find apparent age prediction from a given probas of ages
     Args:
-        age_predictions (?)
+        age_predictions (np.ndarray): 1-D array of age class probabilities, shape (age_classes,)
     Returns:
         apparent_age (float)
     """
+    assert (
+        len(age_predictions.shape) == 1
+    ), f"Input should be a single prediction, not batched. Got shape: {age_predictions.shape}"
     output_indexes = np.arange(0, 101)
     apparent_age = np.sum(age_predictions * output_indexes)
     return apparent_age

diff --git a/deepface/models/demography/Emotion.py b/deepface/models/demography/Emotion.py
index d2633b5..499c246 100644
--- a/deepface/models/demography/Emotion.py
+++ b/deepface/models/demography/Emotion.py
@@ -1,3 +1,6 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np
 import cv2
@@ -43,16 +46,38 @@ class EmotionClient(Demography):
         self.model = load_model()
         self.model_name = "Emotion"

-    def predict(self, img: np.ndarray) -> np.ndarray:
-        img_gray = cv2.cvtColor(img[0], cv2.COLOR_BGR2GRAY)
+    def _preprocess_image(self, img: np.ndarray) -> np.ndarray:
+        """
+        Preprocess a single image for emotion detection
+        Args:
+            img: Input image (224, 224, 3)
+        Returns:
+            Preprocessed grayscale image (48, 48)
+        """
+        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         img_gray = cv2.resize(img_gray, (48, 48))
-        img_gray = np.expand_dims(img_gray, axis=0)
+        return img_gray

-        # model.predict causes memory issue when it is called in a for loop
-        # emotion_predictions = self.model.predict(img_gray, verbose=0)[0, :]
-        emotion_predictions = self.model(img_gray, training=False).numpy()[0, :]
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+        """
+        Predict emotion probabilities for single or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.ndarray (n, n_emotions)
+            where n_emotions is the number of emotion categories
+        """
+        # Preprocess the input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)

-        return emotion_predictions
+        # Convert each image to grayscale and restore the channel dimension:
+        # (n, 48, 48) -> (n, 48, 48, 1)
+        processed_imgs = np.expand_dims(
+            np.array([self._preprocess_image(single_img) for single_img in imgs]), axis=-1
+        )
+
+        # Prediction
+        predictions = self._predict_internal(processed_imgs)
+
+        return predictions


 def load_model(

diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py
index ad1c15e..b6a3ef1 100644
--- a/deepface/models/demography/Gender.py
+++ b/deepface/models/demography/Gender.py
@@ -1,3 +1,7 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np

@@ -37,11 +41,23 @@ class GenderClient(Demography):
         self.model = load_model()
         self.model_name = "Gender"

-    def predict(self, img: np.ndarray) -> np.ndarray:
-        # model.predict causes memory issue when it is called in a for loop
-        # return self.model.predict(img, verbose=0)[0, :]
-        return self.model(img, training=False).numpy()[0, :]
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+        """
+        Predict gender probabilities for single or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.ndarray (n, 2)
+        """
+        # Preprocess the input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)
+
+        # Prediction
+        predictions = self._predict_internal(imgs)
+
+        return predictions


 def load_model(
     url=WEIGHTS_URL,
@@ -64,7 +80,7 @@

     # --------------------------

-    gender_model = Model(inputs=model.input, outputs=base_model_output)
+    gender_model = Model(inputs=model.inputs, outputs=base_model_output)

     # --------------------------

diff --git a/deepface/models/demography/Race.py b/deepface/models/demography/Race.py
index 2334c8b..eae5154 100644
--- a/deepface/models/demography/Race.py
+++ b/deepface/models/demography/Race.py
@@ -1,3 +1,6 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np

@@ -37,10 +40,24 @@ class RaceClient(Demography):
         self.model = load_model()
         self.model_name = "Race"

-    def predict(self, img: np.ndarray) -> np.ndarray:
-        # model.predict causes memory issue when it is called in a for loop
-        # return self.model.predict(img, verbose=0)[0, :]
-        return self.model(img, training=False).numpy()[0, :]
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+        """
+        Predict race probabilities for single or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.ndarray (n, n_races)
+            where n_races is the number of race categories
+        """
+        # Preprocess the input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)
+
+        # Prediction
+        predictions = self._predict_internal(imgs)
+
+        return predictions


 def load_model(
@@ -62,7 +79,7 @@

     # --------------------------

-    race_model = Model(inputs=model.input, outputs=base_model_output)
+    race_model = Model(inputs=model.inputs, outputs=base_model_output)

     # --------------------------

diff --git a/deepface/models/facial_recognition/Dlib.py b/deepface/models/facial_recognition/Dlib.py
index 7b29dec..a2e5ca6 100644
--- a/deepface/models/facial_recognition/Dlib.py
+++ b/deepface/models/facial_recognition/Dlib.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -26,24 +26,22 @@ class DlibClient(FacialRecognition):
         self.input_shape = (150, 150)
         self.output_shape = 128

-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with Dlib model.
             This model necessitates the override of the forward method
             because it is not a keras model.
         Args:
-            img (np.ndarray): pre-loaded image in BGR
+            img (np.ndarray): pre-loaded image(s) in BGR
         Returns
-            embeddings (list): multi-dimensional vector
+            embeddings (list of floats or list of lists of floats): multi-dimensional vectors
         """
-        # return self.model.predict(img)[0].tolist()
-
-        # extract_faces returns 4 dimensional images
-        if len(img.shape) == 4:
-            img = img[0]
+        # Handle single image case: add a batch dimension
+        if len(img.shape) == 3:
+            img = np.expand_dims(img, axis=0)

         # bgr to rgb
-        img = img[:, :, ::-1]  # bgr to rgb
+        img = img[:, :, :, ::-1]

         # img is in scale of [0, 1] but expected [0, 255]
         if img.max() <= 1:
             img = img * 255

         img = img.astype(np.uint8)

-        img_representation = self.model.model.compute_face_descriptor(img)
-        img_representation = np.array(img_representation)
-        img_representation = np.expand_dims(img_representation, axis=0)
-        return img_representation[0].tolist()
+        embeddings = self.model.model.compute_face_descriptor(img)
+        embeddings = [np.array(embedding).tolist() for embedding in embeddings]
+        if len(embeddings) == 1:
+            return embeddings[0]
+        return embeddings


 class DlibResNet:

diff --git a/deepface/models/facial_recognition/SFace.py b/deepface/models/facial_recognition/SFace.py
index f6a01ca..eeebbe3 100644
--- a/deepface/models/facial_recognition/SFace.py
+++ b/deepface/models/facial_recognition/SFace.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, List
+from typing import Any, List, Union

 # 3rd party dependencies
 import numpy as np
@@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
         self.input_shape = (112, 112)
         self.output_shape = 128

-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with SFace model
             This model necessitates the override of the forward method
@@ -37,14 +37,17 @@
         Returns
             embeddings (list): multi-dimensional vector
         """
-        # return self.model.predict(img)[0].tolist()
+        # revert the image to original format and preprocess using the model
+        input_blob = (img * 255).astype(np.uint8)

-        # revert the image to original format and preprocess using the model
-        input_blob = (img[0] * 255).astype(np.uint8)
+        embeddings = []
+        for i in range(input_blob.shape[0]):
+            embedding = self.model.model.feature(input_blob[i])
+            embeddings.append(embedding)
+        embeddings = np.concatenate(embeddings, axis=0)

-        embeddings = self.model.model.feature(input_blob)
-
-        return embeddings[0].tolist()
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()


 def load_model(

diff --git a/deepface/models/facial_recognition/VGGFace.py b/deepface/models/facial_recognition/VGGFace.py
index bfcbcad..bffd0d6 100644
--- a/deepface/models/facial_recognition/VGGFace.py
+++ b/deepface/models/facial_recognition/VGGFace.py
@@ -57,8 +57,7 @@ class VggFaceClient(FacialRecognition):
     def forward(self, img: np.ndarray) -> List[float]:
         """
         Generates embeddings using the VGG-Face model.
-        This method incorporates an additional normalization layer,
-        necessitating the override of the forward method.
+        This method incorporates an additional normalization layer.

         Args:
             img (np.ndarray): pre-loaded image in BGR
@@ -70,8 +69,14 @@
         # having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966)
         # instead we are now calculating it with traditional way not with keras backend
-        embedding = self.model(img, training=False).numpy()[0].tolist()
-        embedding = verification.l2_normalize(embedding)
+        embedding = super().forward(img)
+        # a batch input yields a list of embeddings (list of lists)
+        if isinstance(embedding, list) and isinstance(embedding[0], list):
+            embedding = verification.l2_normalize(embedding, axis=1)
+        else:
+            embedding = verification.l2_normalize(embedding)
         return embedding.tolist()

diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py
index 2258c1e..d3ce8e6 100644
--- a/deepface/modules/demography.py
+++ b/deepface/modules/demography.py
@@ -100,6 +100,29 @@ def analyze(
             - 'white': Confidence score for White ethnicity.
     """

+    if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4:
+        # Received a 4-D array, i.e. a batch of images.
+        # Check the batch dimension and process each image separately.
+        if img_path.shape[0] > 1:
+            batch_resp_obj = []
+            # Execute analysis for each image in the batch.
+            for single_img in img_path:
+                # Call the analyze function for each image in the batch.
+                resp_obj = analyze(
+                    img_path=single_img,
+                    actions=actions,
+                    enforce_detection=enforce_detection,
+                    detector_backend=detector_backend,
+                    align=align,
+                    expand_percentage=expand_percentage,
+                    silent=silent,
+                    anti_spoofing=anti_spoofing,
+                )
+
+                # Append the response object to the batch response list.
+                batch_resp_obj.append(resp_obj)
+            return batch_resp_obj
+
     # if actions is passed as tuple with single item, interestingly it becomes str here
     if isinstance(actions, str):
         actions = (actions,)

diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py
index d0c4fc3..2be4fec 100644
--- a/deepface/modules/representation.py
+++ b/deepface/modules/representation.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union, Optional, Sequence, IO

 # 3rd party dependencies
 import numpy as np
@@ -11,7 +11,7 @@ from deepface.models.FacialRecognition import FacialRecognition


 def represent(
-    img_path: Union[str, np.ndarray],
+    img_path: Union[str, IO[bytes], np.ndarray, Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -20,14 +20,16 @@
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.
     Args:
-        img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
-            or a base64 encoded image. If the source image contains multiple faces, the result will
-            include information for each detected face.
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+            The exact path to the image, a numpy array in BGR format,
+            a base64 encoded image, a binary file object, or a sequence of these.
+            If the source image contains multiple faces,
+            the result will include information for each detected face.

         model_name (str): Model for face recognition.
             Options: VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
             ArcFace, SFace and GhostFaceNet
@@ -51,8 +53,9 @@
         max_faces (int): Set a limit on the number of faces to be processed (default is None).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries,
+            or a list of lists of dictionaries if a batch input was passed.
+            Each dictionary contains the following fields:

         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
@@ -70,80 +73,105 @@
         task="facial_recognition", model_name=model_name
     )

-    # ---------------------------------
-    # we have run pre-process in verification. so, this can be skipped if it is coming from verify.
-    target_size = model.input_shape
-    if detector_backend != "skip":
-        # Images are returned in RGB format.
-        img_objs = detection.extract_faces(
-            img_path=img_path,
-            detector_backend=detector_backend,
-            grayscale=False,
-            enforce_detection=enforce_detection,
-            align=align,
-            expand_percentage=expand_percentage,
-            anti_spoofing=anti_spoofing,
-            max_faces=max_faces,
-        )
-    else:  # skip
-        # Try load. If load error, will raise exception internal
-        img, _ = image_utils.load_image(img_path)
+    # Handle a list of image paths or a 4D numpy array
+    if isinstance(img_path, list):
+        images = img_path
+    elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
+        images = [img_path[i] for i in range(img_path.shape[0])]
+    else:
+        images = [img_path]

-        if len(img.shape) != 3:
-            raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
+    batch_images, batch_regions, batch_confidences, batch_indexes = [], [], [], []

-        # Convert to RGB format to keep compatability with `extract_faces`.
-        img = img[:, :, ::-1]
+    for idx, single_img_path in enumerate(images):
+        # we have run pre-process in verification. so, skip if it is coming from verify.
+        target_size = model.input_shape
+        if detector_backend != "skip":
+            # Images are returned in RGB format.
+            img_objs = detection.extract_faces(
+                img_path=single_img_path,
+                detector_backend=detector_backend,
+                grayscale=False,
+                enforce_detection=enforce_detection,
+                align=align,
+                expand_percentage=expand_percentage,
+                anti_spoofing=anti_spoofing,
+                max_faces=max_faces,
+            )
+        else:  # skip
+            # Try to load the image; loading errors raise an exception internally.
+            img, _ = image_utils.load_image(single_img_path)

-        # make dummy region and confidence to keep compatibility with `extract_faces`
-        img_objs = [
-            {
-                "face": img,
-                "facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
-                "confidence": 0,
-            }
-        ]
-    # ---------------------------------
+            if len(img.shape) != 3:
+                raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")

-    if max_faces is not None and max_faces < len(img_objs):
-        # sort as largest facial areas come first
-        img_objs = sorted(
-            img_objs,
-            key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
-            reverse=True,
-        )
-        # discard rest of the items
-        img_objs = img_objs[0:max_faces]
+            # Convert to RGB format to keep compatability with `extract_faces`.
+            img = img[:, :, ::-1]

-    for img_obj in img_objs:
-        if anti_spoofing is True and img_obj.get("is_real", True) is False:
-            raise ValueError("Spoof detected in the given image.")
-        img = img_obj["face"]
+            # make dummy region and confidence to keep compatibility with `extract_faces`
+            img_objs = [
+                {
+                    "face": img,
+                    "facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
+                    "confidence": 0,
+                }
+            ]
+        # ---------------------------------

-        # rgb to bgr
-        img = img[:, :, ::-1]
+        if max_faces is not None and max_faces < len(img_objs):
+            # sort as largest facial areas come first
+            img_objs = sorted(
+                img_objs,
+                key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
+                reverse=True,
+            )
+            # discard rest of the items
+            img_objs = img_objs[0:max_faces]

-        region = img_obj["facial_area"]
-        confidence = img_obj["confidence"]
+        for img_obj in img_objs:
+            if anti_spoofing is True and img_obj.get("is_real", True) is False:
+                raise ValueError("Spoof detected in the given image.")

-        # resize to expected shape of ml model
-        img = preprocessing.resize_image(
-            img=img,
-            # thanks to DeepId (!)
-            target_size=(target_size[1], target_size[0]),
-        )
+            img = img_obj["face"]

-        # custom normalization
-        img = preprocessing.normalize_input(img=img, normalization=normalization)
+            # rgb to bgr
+            img = img[:, :, ::-1]

-        embedding = model.forward(img)
+            region = img_obj["facial_area"]
+            confidence = img_obj["confidence"]

-        resp_objs.append(
-            {
-                "embedding": embedding,
-                "facial_area": region,
-                "face_confidence": confidence,
-            }
-        )
+            # resize to expected shape of ml model
+            img = preprocessing.resize_image(
+                img=img,
+                # thanks to DeepId (!)
+                target_size=(target_size[1], target_size[0]),
+            )

-    return resp_objs
+            # custom normalization
+            img = preprocessing.normalize_input(img=img, normalization=normalization)
+
+            batch_images.append(img)
+            batch_regions.append(region)
+            batch_confidences.append(confidence)
+            batch_indexes.append(idx)
+
+    # Convert list of images to a numpy array for batch processing
+    batch_images = np.concatenate(batch_images, axis=0)
+
+    # Forward pass through the model for the entire batch
+    embeddings = model.forward(batch_images)
+
+    for idx in range(0, len(images)):
+        resp_obj = []
+        for idy, batch_index in enumerate(batch_indexes):
+            if idx == batch_index:
+                resp_obj.append(
+                    {
+                        "embedding": embeddings if len(batch_images) == 1 else embeddings[idy],
+                        "facial_area": batch_regions[idy],
+                        "face_confidence": batch_confidences[idy],
+                    }
+                )
+        resp_objs.append(resp_obj)
+
+    return resp_objs[0] if len(images) == 1 else resp_objs

diff --git a/icon/github_sponsor_button.png b/icon/github_sponsor_button.png
new file mode 100644
index 0000000..e011d89
Binary files /dev/null and b/icon/github_sponsor_button.png differ

diff --git a/tests/test_analyze.py b/tests/test_analyze.py
index bad4426..6f8c996 100644
--- a/tests/test_analyze.py
+++ b/tests/test_analyze.py
@@ -1,8 +1,10 @@
 # 3rd party dependencies
 import cv2
+import numpy as np

 # project dependencies
 from deepface import DeepFace
+from deepface.models.demography import Age, Emotion, Gender, Race
 from deepface.commons.logger import Logger

 logger = Logger()
@@ -16,6 +18,7 @@ def test_standard_analyze():
     demography_objs = DeepFace.analyze(img, silent=True)
     for demography in demography_objs:
         logger.debug(demography)
+        assert type(demography) == dict
         assert demography["age"] > 20 and demography["age"] < 40
         assert demography["dominant_gender"] == "Woman"
     logger.info("✅ test standard analyze done")
@@ -29,6 +32,7 @@ def test_analyze_with_all_actions_as_tuple():

     for demography in demography_objs:
         logger.debug(f"Demography: {demography}")
+        assert type(demography) == dict
         age = demography["age"]
         gender = demography["dominant_gender"]
         race = demography["dominant_race"]
@@ -53,6 +57,7 @@ def test_analyze_with_all_actions_as_list():

     for demography in demography_objs:
         logger.debug(f"Demography: {demography}")
+        assert type(demography) == dict
         age = demography["age"]
         gender = demography["dominant_gender"]
         race = demography["dominant_race"]
@@ -74,6 +79,7 @@ def test_analyze_for_some_actions():
     demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True)

     for demography in demography_objs:
+        assert type(demography) == dict
         age = demography["age"]
         gender = demography["dominant_gender"]

@@ -95,6 +101,7 @@ def test_analyze_for_preloaded_image():
     resp_objs = DeepFace.analyze(img, silent=True)
     for resp_obj in resp_objs:
         logger.debug(resp_obj)
+        assert type(resp_obj) == dict
         assert resp_obj["age"] > 20 and resp_obj["age"] < 40
         assert resp_obj["dominant_gender"] == "Woman"

@@ -131,7 +138,89 @@ def test_analyze_for_different_detectors():
         ]

         # validate probabilities
+        assert type(result) == dict
         if result["dominant_gender"] == "Man":
             assert result["gender"]["Man"] > result["gender"]["Woman"]
         else:
             assert result["gender"]["Man"] < result["gender"]["Woman"]
+
+
+def test_analyze_for_numpy_batched_image():
+    img1_path = "dataset/img4.jpg"
+    img2_path = "dataset/couple.jpg"
+
+    # Load two images with different numbers of faces and combine them into a batch
+    img1 = cv2.imread(img1_path)
+    img2 = cv2.imread(img2_path)
+
+    expected_num_faces = [1, 2]
+
+    img1 = cv2.resize(img1, (500, 500))
+    img2 = cv2.resize(img2, (500, 500))
+
+    img = np.stack([img1, img2])
+    assert len(img.shape) == 4  # Check dimension.
+    assert img.shape[0] == 2  # Check batch size.
+
+    demography_batch = DeepFace.analyze(img, silent=True)
+    # 2 images in the batch, so 2 demography objects.
+    assert len(demography_batch) == 2
+
+    for i, demography_objs in enumerate(demography_batch):
+        assert len(demography_objs) == expected_num_faces[i]
+        for demography in demography_objs:  # Iterate over faces
+            assert isinstance(demography, dict)  # Check type
+            assert demography["age"] > 20 and demography["age"] < 40
+            assert demography["dominant_gender"] in ["Woman", "Man"]
+
+    logger.info("✅ test analyze for numpy batched image done")
+
+
+def test_batch_detect_age_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Age.ApparentAgeClient().predict(imgs)
+    # Check there are two ages detected
+    assert len(results) == 2
+    # Check the two faces' ages are the same in integer format (e.g. 23.6 -> 23)
+    # Must use int() to compare because of float precision issues across platforms
+    assert np.array_equal(int(results[0]), int(results[1]))
+    logger.info("✅ test batch detect age for multiple faces done")
+
+
+def test_batch_detect_emotion_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Emotion.EmotionClient().predict(imgs)
+    # Check there are two emotions detected
+    assert len(results) == 2
+    # Check the two faces' emotions are the same
+    assert np.array_equal(results[0], results[1])
+    logger.info("✅ test batch detect emotion for multiple faces done")
+
+
+def test_batch_detect_gender_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Gender.GenderClient().predict(imgs)
+    # Check there are two genders detected
+    assert len(results) == 2
+    # Check the two genders are the same
+    assert np.array_equal(results[0], results[1])
+    logger.info("✅ test batch detect gender for multiple faces done")
+
+
+def test_batch_detect_race_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Race.RaceClient().predict(imgs)
+    # Check there are two races detected
+    assert len(results) == 2
+    # Check the two races are the same
+    assert np.array_equal(results[0], results[1])
+    logger.info("✅ test batch detect race for multiple faces done")

diff --git a/tests/test_represent.py b/tests/test_represent.py
index e3c52b8..e5a7eab 100644
--- a/tests/test_represent.py
+++ b/tests/test_represent.py
@@ -2,6 +2,7 @@
 import io
 import cv2
 import pytest
+import numpy as np

 # project dependencies
 from deepface import DeepFace
@@ -13,7 +15,12 @@ logger = Logger()
 def test_standard_represent():
     img_path = "dataset/img1.jpg"
     embedding_objs = DeepFace.represent(img_path)
+    # type should be list of dict
+    assert isinstance(embedding_objs, list)
+
     for embedding_obj in embedding_objs:
+        assert isinstance(embedding_obj, dict)
+
         embedding = embedding_obj["embedding"]
         logger.debug(f"Function returned {len(embedding)} dimensional vector")
         assert len(embedding) == 4096
@@ -23,18 +30,18 @@
 def test_standard_represent_with_io_object():
     img_path = "dataset/img1.jpg"
     default_embedding_objs = DeepFace.represent(img_path)
-    io_embedding_objs = DeepFace.represent(open(img_path, 'rb'))
+    io_embedding_objs = DeepFace.represent(open(img_path, "rb"))
     assert default_embedding_objs == io_embedding_objs

     # Confirm non-seekable io objects are handled properly
-    io_obj = io.BytesIO(open(img_path, 'rb').read())
+    io_obj = io.BytesIO(open(img_path, "rb").read())
     io_obj.seek = None
     no_seek_io_embedding_objs = DeepFace.represent(io_obj)
     assert default_embedding_objs == no_seek_io_embedding_objs

     # Confirm non-image io objects raise exceptions
-    with pytest.raises(ValueError, match='Failed to decode image'):
-        DeepFace.represent(io.BytesIO(open(r'../requirements.txt', 'rb').read()))
+    with pytest.raises(ValueError, match="Failed to decode image"):
+        DeepFace.represent(io.BytesIO(open(r"../requirements.txt", "rb").read()))

     logger.info("✅ test standard represent with io object function done")

@@ -55,6 +62,27 @@ def test_represent_for_skipped_detector_backend_with_image_path():
     logger.info("✅ test represent function for skipped detector and image path input backend done")


+def test_represent_for_preloaded_image():
+    face_img = "dataset/img5.jpg"
+    img = cv2.imread(face_img)
+    img_objs = DeepFace.represent(img_path=img)
+    # type should be list of dict
+    assert isinstance(img_objs, list)
+    assert len(img_objs) >= 1
+
+    for img_obj in img_objs:
+        assert isinstance(img_obj, dict)
+        assert "embedding" in img_obj.keys()
+        assert "facial_area" in img_obj.keys()
+        assert isinstance(img_obj["facial_area"], dict)
+        assert "x" in img_obj["facial_area"].keys()
+        assert "y" in img_obj["facial_area"].keys()
+        assert "w" in img_obj["facial_area"].keys()
+        assert "h" in img_obj["facial_area"].keys()
+        assert "face_confidence" in img_obj.keys()
+    logger.info("✅ test represent function for preloaded image done")
+
+
 def test_represent_for_skipped_detector_backend_with_preloaded_image():
     face_img = "dataset/img5.jpg"
     img = cv2.imread(face_img)
@@ -84,12 +112,6 @@ def test_max_faces():


 def test_represent_detector_backend():
-    """
-    There shouldn't be a difference between:
-    - Using a detector backend provided by `represent`
-    - Manually calling a detector backend, then calling `represent`.
-    """
-    # Results using a detection backend.
     results_1 = DeepFace.represent(img_path="dataset/img1.jpg")
     assert len(results_1) == 1
@@ -108,3 +130,108 @@
     embedding_2 = results_2[0]['embedding']
     assert embedding_1 == embedding_2
     logger.info("✅ test represent function for consistent output.")
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_list_input(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+        "dataset/couple.jpg",
+    ]
+
+    expected_faces = [1, 1, 1, 1, 1, 2]
+
+    batched_embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
+
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+
+    assert len(batched_embedding_objs) == len(
+        img_paths
+    ), f"Expected {len(img_paths)} embeddings, got {len(batched_embedding_objs)}"
+
+    # the last image has two faces
+    for idx, embedding_objs in enumerate(batched_embedding_objs):
+        # batched_embedding_objs is a list already; each embedding_objs should be a list of dict
+        assert isinstance(embedding_objs, list)
+        for embedding_obj in embedding_objs:
+            assert isinstance(embedding_obj, dict)
+
+        assert expected_faces[idx] == len(
+            embedding_objs
+        ), f"{img_paths[idx]} has {expected_faces[idx]} faces, but got {len(embedding_objs)} embeddings!"
+
+    for idx, img_path in enumerate(img_paths):
+        single_embedding_objs = DeepFace.represent(img_path=img_path, model_name=model_name)
+        # type should be list of dict for single input
+        assert isinstance(single_embedding_objs, list)
+        for embedding_obj in single_embedding_objs:
+            assert isinstance(embedding_obj, dict)
+
+        assert len(single_embedding_objs) == len(batched_embedding_objs[idx])
+
+        for alpha, beta in zip(single_embedding_objs, batched_embedding_objs[idx]):
+            assert np.allclose(
+                alpha["embedding"], beta["embedding"], rtol=1e-2, atol=1e-2
+            ), "Embeddings do not match within tolerance"
+
+    logger.info(f"✅ test batch represent function with string input for model {model_name} done")
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_numpy_input(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+        "dataset/couple.jpg",
+    ]
+    expected_faces = [1, 1, 1, 1, 1, 2]
+
+    imgs = []
+    for img_path in img_paths:
+        img = cv2.imread(img_path)
+        img = cv2.resize(img, (1000, 1000))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        imgs.append(img)
+
+    imgs = np.array(imgs)
+    assert imgs.ndim == 4 and imgs.shape[0] == len(img_paths)
+
+    batched_embedding_objs = DeepFace.represent(img_path=imgs, model_name=model_name)
+
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+    for idx, batched_embedding_obj in enumerate(batched_embedding_objs):
+        assert isinstance(batched_embedding_obj, list)
+        # it also has to have the expected number of faces
+        assert len(batched_embedding_obj) == expected_faces[idx]
+        for embedding_obj in batched_embedding_obj:
+            assert isinstance(embedding_obj, dict)
+
+    # we should have the same number of embeddings as the number of images
+    assert len(batched_embedding_objs) == len(img_paths)
+
+    logger.info(f"✅ test batch represent function with numpy input for model {model_name} done")
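
For reference, a minimal usage sketch of the batched inputs this diff enables, mirroring the calls exercised in tests/test_analyze.py and tests/test_represent.py (it assumes the test images under tests/dataset/ are available):

```python
import cv2
import numpy as np

from deepface import DeepFace

# Batched represent: a list of image paths; faces per image may differ.
batched = DeepFace.represent(img_path=["dataset/img1.jpg", "dataset/couple.jpg"])
# Returns List[List[Dict]]: one inner list per input image,
# one dict per detected face in that image.
for per_image in batched:
    for face in per_image:
        print(len(face["embedding"]), face["facial_area"])

# Batched analyze: a 4-D numpy array (n, h, w, 3) of equally sized BGR images.
img1 = cv2.resize(cv2.imread("dataset/img4.jpg"), (500, 500))
img2 = cv2.resize(cv2.imread("dataset/couple.jpg"), (500, 500))
results = DeepFace.analyze(img_path=np.stack([img1, img2]), silent=True)
assert len(results) == 2  # one List[Dict] of per-face results per image

# Single inputs keep the old List[Dict] return type, so existing callers
# are unaffected.
single = DeepFace.represent(img_path="dataset/img1.jpg")
assert isinstance(single[0], dict)
```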
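
And a sketch of the lower-level batch path on the demography clients, as exercised by the new tests (this assumes the Age model weights are already downloaded):

```python
import cv2
from deepface.models.demography import Age

# Demography clients now accept a list or a 4-D batch directly:
# _preprocess_batch_or_single_input stacks the inputs into (n, 224, 224, 3)
# and _predict_internal routes n > 1 through model.predict_on_batch.
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
client = Age.ApparentAgeClient()

single_age = client.predict(img)         # np.float64, legacy single-image path
batch_ages = client.predict([img, img])  # np.ndarray of shape (2,)

# Same face twice in the batch, so the two predicted ages should agree
# (compared as integers, as in the tests, to sidestep float precision).
assert int(batch_ages[0]) == int(batch_ages[1])
```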