diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index 9c7ef3c..5a8fe77 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -1,6 +1,8 @@ # stdlib dependencies + from typing import List, Union + # 3rd party dependencies import numpy as np @@ -81,6 +83,29 @@ class ApparentAgeClient(Demography): return apparent_ages[0] return apparent_ages + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove batch dimension if exists + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension if not exists + imgs_ = np.expand_dims(imgs_, axis=0) + # Batch prediction + age_predictions = self.model.predict_on_batch(imgs_) + apparent_ages = np.array( + [find_apparent_age(age_prediction) for age_prediction in age_predictions] + ) + return apparent_ages + def load_model( url=WEIGHTS_URL, diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py index ac8716a..80e0896 100644 --- a/deepface/models/demography/Gender.py +++ b/deepface/models/demography/Gender.py @@ -1,6 +1,8 @@ # stdlib dependencies + from typing import List, Union + # 3rd party dependencies import numpy as np @@ -76,6 +78,24 @@ class GenderClient(Demography): return predictions[0] return predictions + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove redundant dimensions + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension + imgs_ = np.expand_dims(imgs_, axis=0) + return self.model.predict_on_batch(imgs_) + def load_model( url=WEIGHTS_URL,