[update] add batch predicting for Gender model

This commit is contained in:
Nat Lee 2024-12-05 17:55:17 +08:00
parent c42df046f1
commit a4b1b5d157

View File

@ -1,3 +1,6 @@
# stdlib dependencies
from typing import List
# 3rd party dependencies
import numpy as np
@ -42,6 +45,24 @@ class GenderClient(Demography):
# return self.model.predict(img, verbose=0)[0, :]
return self.model(img, training=False).numpy()[0, :]
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
"""
Predict apparent ages of multiple faces
Args:
imgs (List[np.ndarray]): (n, 224, 224, 3)
Returns:
apparent_ages (np.ndarray): (n,)
"""
# Convert list to numpy array
imgs_:np.ndarray = np.array(imgs)
# Remove redundant dimensions
imgs_ = imgs_.squeeze()
# Check if the input is a single image
if len(imgs_.shape) == 3:
# Add batch dimension
imgs_ = np.expand_dims(imgs_, axis=0)
return self.model.predict_on_batch(imgs_)
def load_model(
url=WEIGHTS_URL,