mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
Patch: Make Age model capable to handle single or batched input.
This commit is contained in:
parent
c72b47484d
commit
7e719dfdeb
@ -41,7 +41,7 @@ class ApparentAgeClient(Demography):
|
||||
self.model = load_model()
|
||||
self.model_name = "Age"
|
||||
|
||||
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
|
||||
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
|
||||
"""
|
||||
Predict apparent age(s) for single or multiple faces
|
||||
Args:
|
||||
@ -49,7 +49,7 @@ class ApparentAgeClient(Demography):
|
||||
List of images as List[np.ndarray] or
|
||||
Batch of images as np.ndarray (n, 224, 224, 3)
|
||||
Returns:
|
||||
np.ndarray (n,)
|
||||
np.ndarray (age_classes,) if single image, np.ndarray (n, age_classes) if batched images.
|
||||
"""
|
||||
# Preprocessing input image or image list.
|
||||
imgs = self._preprocess_batch_or_single_input(img)
|
||||
@ -58,11 +58,11 @@ class ApparentAgeClient(Demography):
|
||||
age_predictions = self._predict_internal(imgs)
|
||||
|
||||
# Calculate apparent ages
|
||||
apparent_ages = np.array(
|
||||
[find_apparent_age(age_prediction) for age_prediction in age_predictions]
|
||||
)
|
||||
if len(age_predictions.shape) == 1: # Single prediction list
|
||||
return find_apparent_age(age_predictions)
|
||||
else: # Batched predictions
|
||||
return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])
|
||||
|
||||
return apparent_ages
|
||||
|
||||
def load_model(
|
||||
url=WEIGHTS_URL,
|
||||
@ -98,15 +98,16 @@ def load_model(
|
||||
|
||||
return age_model
|
||||
|
||||
|
||||
def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
|
||||
"""
|
||||
Find apparent age prediction from a given probas of ages
|
||||
Args:
|
||||
age_predictions (?)
|
||||
age_predictions (age_classes,)
|
||||
Returns:
|
||||
apparent_age (float)
|
||||
"""
|
||||
assert len(age_predictions.shape) == 1, "Input should be a list of age predictions, \
|
||||
not batched. Got shape: {}".format(age_predictions.shape)
|
||||
output_indexes = np.arange(0, 101)
|
||||
apparent_age = np.sum(age_predictions * output_indexes)
|
||||
return apparent_age
|
||||
|
Loading…
x
Reference in New Issue
Block a user