[update] enhance predict methods to support single and batch inputs for Age and Gender models

This commit is contained in:
Nat Lee 2024-12-17 13:44:01 +08:00
parent 29c818d61e
commit 27e8fc9d5e
2 changed files with 63 additions and 38 deletions

View File

@ -1,5 +1,5 @@
# stdlib dependencies
from typing import List
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -40,33 +40,45 @@ class ApparentAgeClient(Demography):
self.model = load_model()
self.model_name = "Age"
def predict(self, img: np.ndarray) -> np.float64:
# model.predict causes memory issue when it is called in a for loop
# age_predictions = self.model.predict(img, verbose=0)[0, :]
age_predictions = self.model(img, training=False).numpy()[0, :]
return find_apparent_age(age_predictions)
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
"""
Predict apparent ages of multiple faces
Predict apparent age(s) for single or multiple faces
Args:
imgs (List[np.ndarray]): (n, 224, 224, 3)
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
apparent_ages (np.ndarray): (n,)
Single age as np.float64 or
Multiple ages as np.ndarray (n,)
"""
# Convert list to numpy array
imgs_:np.ndarray = np.array(imgs)
# Convert to numpy array if input is list
if isinstance(img, list):
imgs = np.array(img)
else:
imgs = img
# Remove batch dimension if exists
imgs_ = imgs_.squeeze()
# Check if the input is a single image
if len(imgs_.shape) == 3:
# Add batch dimension if not exists
imgs_ = np.expand_dims(imgs_, axis=0)
imgs = imgs.squeeze()
# Check input dimension
if len(imgs.shape) == 3:
# Single image - add batch dimension
imgs = np.expand_dims(imgs, axis=0)
is_single = True
else:
is_single = False
# Batch prediction
age_predictions = self.model.predict_on_batch(imgs_)
age_predictions = self.model.predict_on_batch(imgs)
# Calculate apparent ages
apparent_ages = np.array(
[find_apparent_age(age_prediction) for age_prediction in age_predictions]
)
# Return single value for single image
if is_single:
return apparent_ages[0]
return apparent_ages

View File

@ -1,5 +1,5 @@
# stdlib dependencies
from typing import List
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -40,28 +40,41 @@ class GenderClient(Demography):
self.model = load_model()
self.model_name = "Gender"
def predict(self, img: np.ndarray) -> np.ndarray:
# model.predict causes memory issue when it is called in a for loop
# return self.model.predict(img, verbose=0)[0, :]
return self.model(img, training=False).numpy()[0, :]
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]:
"""
Predict apparent ages of multiple faces
Predict gender probabilities for single or multiple faces
Args:
imgs (List[np.ndarray]): (n, 224, 224, 3)
img: Single image as np.ndarray (224, 224, 3) or
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
apparent_ages (np.ndarray): (n,)
Single prediction as np.ndarray (2,) [female_prob, male_prob] or
Multiple predictions as np.ndarray (n, 2)
"""
# Convert list to numpy array
imgs_:np.ndarray = np.array(imgs)
# Remove redundant dimensions
imgs_ = imgs_.squeeze()
# Check if the input is a single image
if len(imgs_.shape) == 3:
# Add batch dimension
imgs_ = np.expand_dims(imgs_, axis=0)
return self.model.predict_on_batch(imgs_)
# Convert to numpy array if input is list
if isinstance(img, list):
imgs = np.array(img)
else:
imgs = img
# Remove batch dimension if exists
imgs = imgs.squeeze()
# Check input dimension
if len(imgs.shape) == 3:
# Single image - add batch dimension
imgs = np.expand_dims(imgs, axis=0)
is_single = True
else:
is_single = False
# Batch prediction
predictions = self.model.predict_on_batch(imgs)
# Return single prediction for single image
if is_single:
return predictions[0]
return predictions
def load_model(