From 688fbe6b902f5a368142284fad03cc6054bb0d9b Mon Sep 17 00:00:00 2001 From: NatLee Date: Mon, 13 Jan 2025 22:27:11 +0800 Subject: [PATCH] [fix] lint --- deepface/models/Demography.py | 38 +++++++++++++++++++--------------- deepface/modules/demography.py | 8 +++---- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/deepface/models/Demography.py b/deepface/models/Demography.py index fb9d106..329a156 100644 --- a/deepface/models/Demography.py +++ b/deepface/models/Demography.py @@ -28,24 +28,32 @@ class Demography(ABC): And switch to batch prediction if receives batched images. Args: - img_batch: Batch of images as np.ndarray (n, x, y, c), with n >= 1, x = image width, y = image height, c = channel - Or Single image as np.ndarray (1, x, y, c), with x = image width, y = image height and c = channel - The channel dimension may be omitted if the image is grayscale. (For emotion model) + img_batch: + Batch of images as np.ndarray (n, x, y, c) + with n >= 1, x = image width, y = image height, c = channel + Or Single image as np.ndarray (1, x, y, c) + with x = image width, y = image height and c = channel + The channel dimension may be omitted if the image is grayscale. (For emotion model) """ if not self.model_name: # Check if called from derived class raise NotImplementedError("no model selected") - assert img_batch.ndim == 4, "expected 4-dimensional tensor input" + # Single image + if img_batch.shape[0] == 1: + # Check if grayscale by checking last dimension, if not 3, it is grayscale. + if img_batch.shape[-1] != 3: + # Remove batch dimension + img_batch = img_batch.squeeze(0) + # Predict with legacy method. + return self.model(img_batch, training=False).numpy()[0, :] + # Batch of images + # Predict with batch prediction + return self.model.predict_on_batch(img_batch) - if img_batch.shape[0] == 1: # Single image - if img_batch.shape[-1] != 3: # Check if grayscale by checking last dimension, if not 3, it is grayscale. - img_batch = img_batch.squeeze(0) # Remove batch dimension - predict_result = self.model(img_batch, training=False).numpy()[0, :] # Predict with legacy method. - return predict_result - else: # Batch of images - return self.model.predict_on_batch(img_batch) # Predict with batch prediction - - def _preprocess_batch_or_single_input(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + def _preprocess_batch_or_single_input( + self, + img: Union[np.ndarray, List[np.ndarray]] + ) -> np.ndarray: """ Preprocess single or batch of images, return as 4-D numpy array. @@ -56,15 +64,11 @@ class Demography(ABC): Returns: Four-dimensional numpy array (n, 224, 224, 3) """ - image_batch = np.array(img) - # Remove batch dimension in advance if exists image_batch = image_batch.squeeze() - # Check input dimension if len(image_batch.shape) == 3: # Single image - add batch dimension image_batch = np.expand_dims(image_batch, axis=0) - return image_batch diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py index 9789007..c78f73b 100644 --- a/deepface/modules/demography.py +++ b/deepface/modules/demography.py @@ -168,9 +168,9 @@ def analyze( model = modeling.build_model(task="facial_attribute", model_name=action.capitalize()) predictions = model.predict(faces_array) - # If the model returns a single prediction, reshape it to match the number of faces - # Use number of faces and number of predictions shape to determine the correct shape of predictions - # For example, if there are 1 face to predict with Emotion model, reshape predictions to (1, 7) + # If the model returns a single prediction, reshape it to match the number of faces. + # Determine the correct shape of predictions by using number of faces and predictions shape. + # Example: For 1 face with Emotion model, predictions will be reshaped to (1, 7). if faces_array.shape[0] == 1 and len(predictions.shape) == 1: # For models like `Emotion`, which return a single prediction for a single face predictions = predictions.reshape(1, -1) @@ -229,4 +229,4 @@ def analyze( for result, race_result in zip(results, race_results): result.update(race_result) - return results \ No newline at end of file + return results