diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index 9c7ef3c..adef9fe 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -1,4 +1,5 @@ # stdlib dependencies + from typing import List, Union # 3rd party dependencies @@ -82,6 +83,31 @@ class ApparentAgeClient(Demography): return apparent_ages + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove batch dimension if exists + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension if not exists + imgs_ = np.expand_dims(imgs_, axis=0) + # Batch prediction + age_predictions = self.model.predict_on_batch(imgs_) + apparent_ages = np.array( + [find_apparent_age(age_prediction) for age_prediction in age_predictions] + ) + return apparent_ages + + + def load_model( url=WEIGHTS_URL, ) -> Model: diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py index ac8716a..14d6780 100644 --- a/deepface/models/demography/Gender.py +++ b/deepface/models/demography/Gender.py @@ -1,4 +1,5 @@ # stdlib dependencies + from typing import List, Union # 3rd party dependencies @@ -77,6 +78,26 @@ class GenderClient(Demography): return predictions + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove redundant dimensions + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension + imgs_ = np.expand_dims(imgs_, axis=0) + return self.model.predict_on_batch(imgs_) + + + def load_model( url=WEIGHTS_URL, ) -> Model: