mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
Merge branch 'feat/batch-predict-age-and-gender' into feat/merge-predicts-functions
This commit is contained in:
commit
9c079e94ae
@ -1,6 +1,8 @@
|
||||
# stdlib dependencies
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
# 3rd party dependencies
|
||||
import numpy as np
|
||||
|
||||
@ -81,6 +83,29 @@ class ApparentAgeClient(Demography):
|
||||
return apparent_ages[0]
|
||||
return apparent_ages
|
||||
|
||||
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
|
||||
"""
|
||||
Predict apparent ages of multiple faces
|
||||
Args:
|
||||
imgs (List[np.ndarray]): (n, 224, 224, 3)
|
||||
Returns:
|
||||
apparent_ages (np.ndarray): (n,)
|
||||
"""
|
||||
# Convert list to numpy array
|
||||
imgs_:np.ndarray = np.array(imgs)
|
||||
# Remove batch dimension if exists
|
||||
imgs_ = imgs_.squeeze()
|
||||
# Check if the input is a single image
|
||||
if len(imgs_.shape) == 3:
|
||||
# Add batch dimension if not exists
|
||||
imgs_ = np.expand_dims(imgs_, axis=0)
|
||||
# Batch prediction
|
||||
age_predictions = self.model.predict_on_batch(imgs_)
|
||||
apparent_ages = np.array(
|
||||
[find_apparent_age(age_prediction) for age_prediction in age_predictions]
|
||||
)
|
||||
return apparent_ages
|
||||
|
||||
|
||||
def load_model(
|
||||
url=WEIGHTS_URL,
|
||||
|
@ -1,6 +1,8 @@
|
||||
# stdlib dependencies
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
# 3rd party dependencies
|
||||
import numpy as np
|
||||
|
||||
@ -76,6 +78,24 @@ class GenderClient(Demography):
|
||||
return predictions[0]
|
||||
return predictions
|
||||
|
||||
def predicts(self, imgs: List[np.ndarray]) -> np.ndarray:
|
||||
"""
|
||||
Predict apparent ages of multiple faces
|
||||
Args:
|
||||
imgs (List[np.ndarray]): (n, 224, 224, 3)
|
||||
Returns:
|
||||
apparent_ages (np.ndarray): (n,)
|
||||
"""
|
||||
# Convert list to numpy array
|
||||
imgs_:np.ndarray = np.array(imgs)
|
||||
# Remove redundant dimensions
|
||||
imgs_ = imgs_.squeeze()
|
||||
# Check if the input is a single image
|
||||
if len(imgs_.shape) == 3:
|
||||
# Add batch dimension
|
||||
imgs_ = np.expand_dims(imgs_, axis=0)
|
||||
return self.model.predict_on_batch(imgs_)
|
||||
|
||||
|
||||
def load_model(
|
||||
url=WEIGHTS_URL,
|
||||
|
Loading…
x
Reference in New Issue
Block a user