diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index ae53487..e449aca 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -41,7 +41,7 @@ class ApparentAgeClient(Demography): self.model = load_model() self.model_name = "Age" - def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]: """ Predict apparent age(s) for single or multiple faces Args: @@ -49,7 +49,7 @@ class ApparentAgeClient(Demography): List of images as List[np.ndarray] or Batch of images as np.ndarray (n, 224, 224, 3) Returns: - np.ndarray (n,) + np.ndarray (age_classes,) if single image, np.ndarray (n, age_classes) if batched images. """ # Preprocessing input image or image list. imgs = self._preprocess_batch_or_single_input(img) @@ -58,11 +58,11 @@ class ApparentAgeClient(Demography): age_predictions = self._predict_internal(imgs) # Calculate apparent ages - apparent_ages = np.array( - [find_apparent_age(age_prediction) for age_prediction in age_predictions] - ) + if len(age_predictions.shape) == 1: # Single prediction list + return find_apparent_age(age_predictions) + else: # Batched predictions + return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions]) - return apparent_ages def load_model( url=WEIGHTS_URL, @@ -98,15 +98,16 @@ def load_model( return age_model - def find_apparent_age(age_predictions: np.ndarray) -> np.float64: """ Find apparent age prediction from a given probas of ages Args: - age_predictions (?) + age_predictions (age_classes,) Returns: apparent_age (float) """ + assert len(age_predictions.shape) == 1, "Input should be a list of age predictions, \ + not batched. Got shape: {}".format(age_predictions.shape) output_indexes = np.arange(0, 101) apparent_age = np.sum(age_predictions * output_indexes) return apparent_age