diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index 67ab3ae..f3a21a1 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -1,3 +1,6 @@ +# stdlib dependencies +from typing import List + # 3rd party dependencies import numpy as np @@ -43,6 +46,27 @@ class ApparentAgeClient(Demography): age_predictions = self.model(img, training=False).numpy()[0, :] return find_apparent_age(age_predictions) + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove batch dimension if exists + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension if not exists + imgs_ = np.expand_dims(imgs_, axis=0) + # Batch prediction + age_predictions = self.model.predict_on_batch(imgs_) + apparent_ages = np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions]) + return apparent_ages + def load_model( url=WEIGHTS_URL,