Merge branch 'feat/merge-predicts-functions' into feat/make-Race-and-Emotion-batch

This commit is contained in:
halice 2024-12-31 12:17:26 +08:00 committed by GitHub
commit bba4322bfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 0 deletions

View File

@ -1,4 +1,5 @@
# stdlib dependencies # stdlib dependencies
from typing import List, Union from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
@ -82,6 +83,31 @@ class ApparentAgeClient(Demography):
return apparent_ages return apparent_ages
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
"""
Predict apparent ages of multiple faces
Args:
imgs (List[np.ndarray]): (n, 224, 224, 3)
Returns:
apparent_ages (np.ndarray): (n,)
"""
# Convert list to numpy array
imgs_:np.ndarray = np.array(imgs)
# Remove batch dimension if exists
imgs_ = imgs_.squeeze()
# Check if the input is a single image
if len(imgs_.shape) == 3:
# Add batch dimension if not exists
imgs_ = np.expand_dims(imgs_, axis=0)
# Batch prediction
age_predictions = self.model.predict_on_batch(imgs_)
apparent_ages = np.array(
[find_apparent_age(age_prediction) for age_prediction in age_predictions]
)
return apparent_ages
def load_model( def load_model(
url=WEIGHTS_URL, url=WEIGHTS_URL,
) -> Model: ) -> Model:

View File

@ -1,4 +1,5 @@
# stdlib dependencies # stdlib dependencies
from typing import List, Union from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
@ -77,6 +78,26 @@ class GenderClient(Demography):
return predictions return predictions
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
"""
Predict apparent ages of multiple faces
Args:
imgs (List[np.ndarray]): (n, 224, 224, 3)
Returns:
apparent_ages (np.ndarray): (n,)
"""
# Convert list to numpy array
imgs_:np.ndarray = np.array(imgs)
# Remove redundant dimensions
imgs_ = imgs_.squeeze()
# Check if the input is a single image
if len(imgs_.shape) == 3:
# Add batch dimension
imgs_ = np.expand_dims(imgs_, axis=0)
return self.model.predict_on_batch(imgs_)
def load_model( def load_model(
url=WEIGHTS_URL, url=WEIGHTS_URL,
) -> Model: ) -> Model: