Merge pull request #1396 from NatLee/feat/batch-predict-age-and-gender

Feat/batch predict age and gender
This commit is contained in:
Sefik Ilkin Serengil 2025-02-16 19:43:17 +00:00 committed by GitHub
commit 112d1892fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 258 additions and 29 deletions

View File

@ -174,7 +174,7 @@ def analyze(
expand_percentage: int = 0, expand_percentage: int = 0,
silent: bool = False, silent: bool = False,
anti_spoofing: bool = False, anti_spoofing: bool = False,
) -> List[Dict[str, Any]]: ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
""" """
Analyze facial attributes such as age, gender, emotion, and race in the provided image. Analyze facial attributes such as age, gender, emotion, and race in the provided image.
Args: Args:
@ -206,7 +206,10 @@ def analyze(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False). anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns: Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents (List[List[Dict[str, Any]]]): A list of analysis results if received batched image,
explained below.
(List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
the analysis results for a detected face. Each dictionary in the list contains the the analysis results for a detected face. Each dictionary in the list contains the
following keys: following keys:

View File

@ -1,4 +1,4 @@
from typing import Union from typing import Union, List
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
from deepface.commons import package_utils from deepface.commons import package_utils
@ -18,5 +18,53 @@ class Demography(ABC):
model_name: str model_name: str
@abstractmethod @abstractmethod
def predict(self, img: np.ndarray) -> Union[np.ndarray, np.float64]: def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.float64]:
pass pass
def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
"""
Predict for single image or batched images.
This method uses legacy method while receiving single image as input.
And switch to batch prediction if receives batched images.
Args:
img_batch:
Batch of images as np.ndarray (n, x, y, c)
with n >= 1, x = image width, y = image height, c = channel
Or Single image as np.ndarray (1, x, y, c)
with x = image width, y = image height and c = channel
The channel dimension will be 1 if input is grayscale. (For emotion model)
"""
if not self.model_name: # Check if called from derived class
raise NotImplementedError("no model selected")
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
if img_batch.shape[0] == 1: # Single image
# Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :]
# Batch of images
# Predict with batch prediction
return self.model.predict_on_batch(img_batch)
def _preprocess_batch_or_single_input(
self,
img: Union[np.ndarray, List[np.ndarray]]
) -> np.ndarray:
"""
Preprocess single or batch of images, return as 4-D numpy array.
Args:
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
Four-dimensional numpy array (n, 224, 224, 3)
"""
image_batch = np.array(img)
# Check input dimension
if len(image_batch.shape) == 3:
# Single image - add batch dimension
image_batch = np.expand_dims(image_batch, axis=0)
return image_batch

View File

@ -1,3 +1,7 @@
# stdlib dependencies
from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -37,12 +41,30 @@ class ApparentAgeClient(Demography):
self.model = load_model() self.model = load_model()
self.model_name = "Age" self.model_name = "Age"
def predict(self, img: np.ndarray) -> np.float64: def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
# model.predict causes memory issue when it is called in a for loop """
# age_predictions = self.model.predict(img, verbose=0)[0, :] Predict apparent age(s) for single or multiple faces
age_predictions = self.model(img, training=False).numpy()[0, :] Args:
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
np.ndarray (age_classes,) if single image,
np.ndarray (n, age_classes) if batched images.
"""
# Preprocessing input image or image list.
imgs = self._preprocess_batch_or_single_input(img)
# Prediction from 3 channels image
age_predictions = self._predict_internal(imgs)
# Calculate apparent ages
if len(age_predictions.shape) == 1: # Single prediction list
return find_apparent_age(age_predictions) return find_apparent_age(age_predictions)
return np.array([
find_apparent_age(age_prediction) for age_prediction in age_predictions])
def load_model( def load_model(
url=WEIGHTS_URL, url=WEIGHTS_URL,
@ -65,7 +87,7 @@ def load_model(
# -------------------------- # --------------------------
age_model = Model(inputs=model.input, outputs=base_model_output) age_model = Model(inputs=model.inputs, outputs=base_model_output)
# -------------------------- # --------------------------
@ -78,15 +100,16 @@ def load_model(
return age_model return age_model
def find_apparent_age(age_predictions: np.ndarray) -> np.float64: def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
""" """
Find apparent age prediction from a given probas of ages Find apparent age prediction from a given probas of ages
Args: Args:
age_predictions (?) age_predictions (age_classes,)
Returns: Returns:
apparent_age (float) apparent_age (float)
""" """
assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
not batched. Got shape: {age_predictions.shape}"
output_indexes = np.arange(0, 101) output_indexes = np.arange(0, 101)
apparent_age = np.sum(age_predictions * output_indexes) apparent_age = np.sum(age_predictions * output_indexes)
return apparent_age return apparent_age

View File

@ -1,3 +1,6 @@
# stdlib dependencies
from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
import cv2 import cv2
@ -43,16 +46,38 @@ class EmotionClient(Demography):
self.model = load_model() self.model = load_model()
self.model_name = "Emotion" self.model_name = "Emotion"
def predict(self, img: np.ndarray) -> np.ndarray: def _preprocess_image(self, img: np.ndarray) -> np.ndarray:
img_gray = cv2.cvtColor(img[0], cv2.COLOR_BGR2GRAY) """
Preprocess single image for emotion detection
Args:
img: Input image (224, 224, 3)
Returns:
Preprocessed grayscale image (48, 48)
"""
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_gray = cv2.resize(img_gray, (48, 48)) img_gray = cv2.resize(img_gray, (48, 48))
img_gray = np.expand_dims(img_gray, axis=0) return img_gray
# model.predict causes memory issue when it is called in a for loop def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
# emotion_predictions = self.model.predict(img_gray, verbose=0)[0, :] """
emotion_predictions = self.model(img_gray, training=False).numpy()[0, :] Predict emotion probabilities for single or multiple faces
Args:
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
np.ndarray (n, n_emotions)
where n_emotions is the number of emotion categories
"""
# Preprocessing input image or image list.
imgs = self._preprocess_batch_or_single_input(img)
return emotion_predictions processed_imgs = np.expand_dims(np.array([self._preprocess_image(img) for img in imgs]), axis=-1)
# Prediction
predictions = self._predict_internal(processed_imgs)
return predictions
def load_model( def load_model(

View File

@ -1,3 +1,7 @@
# stdlib dependencies
from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -37,11 +41,23 @@ class GenderClient(Demography):
self.model = load_model() self.model = load_model()
self.model_name = "Gender" self.model_name = "Gender"
def predict(self, img: np.ndarray) -> np.ndarray: def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
# model.predict causes memory issue when it is called in a for loop """
# return self.model.predict(img, verbose=0)[0, :] Predict gender probabilities for single or multiple faces
return self.model(img, training=False).numpy()[0, :] Args:
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
np.ndarray (n, 2)
"""
# Preprocessing input image or image list.
imgs = self._preprocess_batch_or_single_input(img)
# Prediction
predictions = self._predict_internal(imgs)
return predictions
def load_model( def load_model(
url=WEIGHTS_URL, url=WEIGHTS_URL,
@ -64,7 +80,7 @@ def load_model(
# -------------------------- # --------------------------
gender_model = Model(inputs=model.input, outputs=base_model_output) gender_model = Model(inputs=model.inputs, outputs=base_model_output)
# -------------------------- # --------------------------

View File

@ -1,3 +1,6 @@
# stdlib dependencies
from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -37,10 +40,24 @@ class RaceClient(Demography):
self.model = load_model() self.model = load_model()
self.model_name = "Race" self.model_name = "Race"
def predict(self, img: np.ndarray) -> np.ndarray: def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
# model.predict causes memory issue when it is called in a for loop """
# return self.model.predict(img, verbose=0)[0, :] Predict race probabilities for single or multiple faces
return self.model(img, training=False).numpy()[0, :] Args:
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
np.ndarray (n, n_races)
where n_races is the number of race categories
"""
# Preprocessing input image or image list.
imgs = self._preprocess_batch_or_single_input(img)
# Prediction
predictions = self._predict_internal(imgs)
return predictions
def load_model( def load_model(
@ -62,7 +79,7 @@ def load_model(
# -------------------------- # --------------------------
race_model = Model(inputs=model.input, outputs=base_model_output) race_model = Model(inputs=model.inputs, outputs=base_model_output)
# -------------------------- # --------------------------

View File

@ -100,6 +100,30 @@ def analyze(
- 'white': Confidence score for White ethnicity. - 'white': Confidence score for White ethnicity.
""" """
if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4:
# Received 4-D array, which means image batch.
# Check batch dimension and process each image separately.
if img_path.shape[0] > 1:
batch_resp_obj = []
# Execute analysis for each image in the batch.
for single_img in img_path:
# Call the analyze function for each image in the batch.
resp_obj = analyze(
img_path=single_img,
actions=actions,
enforce_detection=enforce_detection,
detector_backend=detector_backend,
align=align,
expand_percentage=expand_percentage,
silent=silent,
anti_spoofing=anti_spoofing,
)
# Append the response object to the batch response list.
batch_resp_obj.append(resp_obj)
return batch_resp_obj
# if actions is passed as tuple with single item, interestingly it becomes str here # if actions is passed as tuple with single item, interestingly it becomes str here
if isinstance(actions, str): if isinstance(actions, str):
actions = (actions,) actions = (actions,)

View File

@ -1,8 +1,10 @@
# 3rd party dependencies # 3rd party dependencies
import cv2 import cv2
import numpy as np
# project dependencies # project dependencies
from deepface import DeepFace from deepface import DeepFace
from deepface.models.demography import Age, Emotion, Gender, Race
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
@ -16,6 +18,7 @@ def test_standard_analyze():
demography_objs = DeepFace.analyze(img, silent=True) demography_objs = DeepFace.analyze(img, silent=True)
for demography in demography_objs: for demography in demography_objs:
logger.debug(demography) logger.debug(demography)
assert type(demography) == dict
assert demography["age"] > 20 and demography["age"] < 40 assert demography["age"] > 20 and demography["age"] < 40
assert demography["dominant_gender"] == "Woman" assert demography["dominant_gender"] == "Woman"
logger.info("✅ test standard analyze done") logger.info("✅ test standard analyze done")
@ -29,6 +32,7 @@ def test_analyze_with_all_actions_as_tuple():
for demography in demography_objs: for demography in demography_objs:
logger.debug(f"Demography: {demography}") logger.debug(f"Demography: {demography}")
assert type(demography) == dict
age = demography["age"] age = demography["age"]
gender = demography["dominant_gender"] gender = demography["dominant_gender"]
race = demography["dominant_race"] race = demography["dominant_race"]
@ -53,6 +57,7 @@ def test_analyze_with_all_actions_as_list():
for demography in demography_objs: for demography in demography_objs:
logger.debug(f"Demography: {demography}") logger.debug(f"Demography: {demography}")
assert type(demography) == dict
age = demography["age"] age = demography["age"]
gender = demography["dominant_gender"] gender = demography["dominant_gender"]
race = demography["dominant_race"] race = demography["dominant_race"]
@ -74,6 +79,7 @@ def test_analyze_for_some_actions():
demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True) demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True)
for demography in demography_objs: for demography in demography_objs:
assert type(demography) == dict
age = demography["age"] age = demography["age"]
gender = demography["dominant_gender"] gender = demography["dominant_gender"]
@ -95,6 +101,7 @@ def test_analyze_for_preloaded_image():
resp_objs = DeepFace.analyze(img, silent=True) resp_objs = DeepFace.analyze(img, silent=True)
for resp_obj in resp_objs: for resp_obj in resp_objs:
logger.debug(resp_obj) logger.debug(resp_obj)
assert type(resp_obj) == dict
assert resp_obj["age"] > 20 and resp_obj["age"] < 40 assert resp_obj["age"] > 20 and resp_obj["age"] < 40
assert resp_obj["dominant_gender"] == "Woman" assert resp_obj["dominant_gender"] == "Woman"
@ -131,7 +138,73 @@ def test_analyze_for_different_detectors():
] ]
# validate probabilities # validate probabilities
assert type(result) == dict
if result["dominant_gender"] == "Man": if result["dominant_gender"] == "Man":
assert result["gender"]["Man"] > result["gender"]["Woman"] assert result["gender"]["Man"] > result["gender"]["Woman"]
else: else:
assert result["gender"]["Man"] < result["gender"]["Woman"] assert result["gender"]["Man"] < result["gender"]["Woman"]
def test_analyze_for_batched_image():
img = "dataset/img4.jpg"
# Copy and combine the same image to create multiple faces
img = cv2.imread(img)
img = np.stack([img, img])
assert len(img.shape) == 4 # Check dimension.
assert img.shape[0] == 2 # Check batch size.
demography_batch = DeepFace.analyze(img, silent=True)
# 2 image in batch, so 2 demography objects.
assert len(demography_batch) == 2
for demography_objs in demography_batch:
assert len(demography_objs) == 1 # 1 face in each image
for demography in demography_objs: # Iterate over faces
assert type(demography) == dict # Check type
assert demography["age"] > 20 and demography["age"] < 40
assert demography["dominant_gender"] == "Woman"
logger.info("✅ test analyze for multiple faces done")
def test_batch_detect_age_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
imgs = [img, img]
results = Age.ApparentAgeClient().predict(imgs)
# Check there are two ages detected
assert len(results) == 2
# Check two faces ages are the same in integer formate.g. 23.6 -> 23
# Must use int() to compare because of max float precision issue in different platforms
assert np.array_equal(int(results[0]), int(results[1]))
logger.info("✅ test batch detect age for multiple faces done")
def test_batch_detect_emotion_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
imgs = [img, img]
results = Emotion.EmotionClient().predict(imgs)
# Check there are two emotions detected
assert len(results) == 2
# Check two faces emotions are the same
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect emotion for multiple faces done")
def test_batch_detect_gender_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
imgs = [img, img]
results = Gender.GenderClient().predict(imgs)
# Check there are two genders detected
assert len(results) == 2
# Check two genders are the same
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect gender for multiple faces done")
def test_batch_detect_race_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
imgs = [img, img]
results = Race.RaceClient().predict(imgs)
# Check there are two races detected
assert len(results) == 2
# Check two races are the same
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect race for multiple faces done")