mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 12:05:22 +00:00
Merge pull request #1396 from NatLee/feat/batch-predict-age-and-gender
Feat/batch predict age and gender
This commit is contained in:
commit
112d1892fd
@ -174,7 +174,7 @@ def analyze(
|
|||||||
expand_percentage: int = 0,
|
expand_percentage: int = 0,
|
||||||
silent: bool = False,
|
silent: bool = False,
|
||||||
anti_spoofing: bool = False,
|
anti_spoofing: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
|
||||||
"""
|
"""
|
||||||
Analyze facial attributes such as age, gender, emotion, and race in the provided image.
|
Analyze facial attributes such as age, gender, emotion, and race in the provided image.
|
||||||
Args:
|
Args:
|
||||||
@ -206,7 +206,10 @@ def analyze(
|
|||||||
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
|
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
|
(List[List[Dict[str, Any]]]): A list of analysis results if received batched image,
|
||||||
|
explained below.
|
||||||
|
|
||||||
|
(List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
|
||||||
the analysis results for a detected face. Each dictionary in the list contains the
|
the analysis results for a detected face. Each dictionary in the list contains the
|
||||||
following keys:
|
following keys:
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Union
|
from typing import Union, List
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from deepface.commons import package_utils
|
from deepface.commons import package_utils
|
||||||
@ -18,5 +18,53 @@ class Demography(ABC):
|
|||||||
model_name: str
|
model_name: str
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, img: np.ndarray) -> Union[np.ndarray, np.float64]:
|
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.float64]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Predict for single image or batched images.
|
||||||
|
This method uses legacy method while receiving single image as input.
|
||||||
|
And switch to batch prediction if receives batched images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_batch:
|
||||||
|
Batch of images as np.ndarray (n, x, y, c)
|
||||||
|
with n >= 1, x = image width, y = image height, c = channel
|
||||||
|
Or Single image as np.ndarray (1, x, y, c)
|
||||||
|
with x = image width, y = image height and c = channel
|
||||||
|
The channel dimension will be 1 if input is grayscale. (For emotion model)
|
||||||
|
"""
|
||||||
|
if not self.model_name: # Check if called from derived class
|
||||||
|
raise NotImplementedError("no model selected")
|
||||||
|
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
|
||||||
|
|
||||||
|
if img_batch.shape[0] == 1: # Single image
|
||||||
|
# Predict with legacy method.
|
||||||
|
return self.model(img_batch, training=False).numpy()[0, :]
|
||||||
|
|
||||||
|
# Batch of images
|
||||||
|
# Predict with batch prediction
|
||||||
|
return self.model.predict_on_batch(img_batch)
|
||||||
|
|
||||||
|
def _preprocess_batch_or_single_input(
|
||||||
|
self,
|
||||||
|
img: Union[np.ndarray, List[np.ndarray]]
|
||||||
|
) -> np.ndarray:
|
||||||
|
|
||||||
|
"""
|
||||||
|
Preprocess single or batch of images, return as 4-D numpy array.
|
||||||
|
Args:
|
||||||
|
img: Single image as np.ndarray (224, 224, 3) or
|
||||||
|
List of images as List[np.ndarray] or
|
||||||
|
Batch of images as np.ndarray (n, 224, 224, 3)
|
||||||
|
Returns:
|
||||||
|
Four-dimensional numpy array (n, 224, 224, 3)
|
||||||
|
"""
|
||||||
|
image_batch = np.array(img)
|
||||||
|
|
||||||
|
# Check input dimension
|
||||||
|
if len(image_batch.shape) == 3:
|
||||||
|
# Single image - add batch dimension
|
||||||
|
image_batch = np.expand_dims(image_batch, axis=0)
|
||||||
|
return image_batch
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
# stdlib dependencies
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -37,12 +41,30 @@ class ApparentAgeClient(Demography):
|
|||||||
self.model = load_model()
|
self.model = load_model()
|
||||||
self.model_name = "Age"
|
self.model_name = "Age"
|
||||||
|
|
||||||
def predict(self, img: np.ndarray) -> np.float64:
|
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
|
||||||
# model.predict causes memory issue when it is called in a for loop
|
"""
|
||||||
# age_predictions = self.model.predict(img, verbose=0)[0, :]
|
Predict apparent age(s) for single or multiple faces
|
||||||
age_predictions = self.model(img, training=False).numpy()[0, :]
|
Args:
|
||||||
|
img: Single image as np.ndarray (224, 224, 3) or
|
||||||
|
List of images as List[np.ndarray] or
|
||||||
|
Batch of images as np.ndarray (n, 224, 224, 3)
|
||||||
|
Returns:
|
||||||
|
np.ndarray (age_classes,) if single image,
|
||||||
|
np.ndarray (n, age_classes) if batched images.
|
||||||
|
"""
|
||||||
|
# Preprocessing input image or image list.
|
||||||
|
imgs = self._preprocess_batch_or_single_input(img)
|
||||||
|
|
||||||
|
# Prediction from 3 channels image
|
||||||
|
age_predictions = self._predict_internal(imgs)
|
||||||
|
|
||||||
|
# Calculate apparent ages
|
||||||
|
if len(age_predictions.shape) == 1: # Single prediction list
|
||||||
return find_apparent_age(age_predictions)
|
return find_apparent_age(age_predictions)
|
||||||
|
|
||||||
|
return np.array([
|
||||||
|
find_apparent_age(age_prediction) for age_prediction in age_predictions])
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
url=WEIGHTS_URL,
|
url=WEIGHTS_URL,
|
||||||
@ -65,7 +87,7 @@ def load_model(
|
|||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
age_model = Model(inputs=model.input, outputs=base_model_output)
|
age_model = Model(inputs=model.inputs, outputs=base_model_output)
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
@ -78,15 +100,16 @@ def load_model(
|
|||||||
|
|
||||||
return age_model
|
return age_model
|
||||||
|
|
||||||
|
|
||||||
def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
|
def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
|
||||||
"""
|
"""
|
||||||
Find apparent age prediction from a given probas of ages
|
Find apparent age prediction from a given probas of ages
|
||||||
Args:
|
Args:
|
||||||
age_predictions (?)
|
age_predictions (age_classes,)
|
||||||
Returns:
|
Returns:
|
||||||
apparent_age (float)
|
apparent_age (float)
|
||||||
"""
|
"""
|
||||||
|
assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
|
||||||
|
not batched. Got shape: {age_predictions.shape}"
|
||||||
output_indexes = np.arange(0, 101)
|
output_indexes = np.arange(0, 101)
|
||||||
apparent_age = np.sum(age_predictions * output_indexes)
|
apparent_age = np.sum(age_predictions * output_indexes)
|
||||||
return apparent_age
|
return apparent_age
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
# stdlib dependencies
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
@ -43,16 +46,38 @@ class EmotionClient(Demography):
|
|||||||
self.model = load_model()
|
self.model = load_model()
|
||||||
self.model_name = "Emotion"
|
self.model_name = "Emotion"
|
||||||
|
|
||||||
def predict(self, img: np.ndarray) -> np.ndarray:
|
def _preprocess_image(self, img: np.ndarray) -> np.ndarray:
|
||||||
img_gray = cv2.cvtColor(img[0], cv2.COLOR_BGR2GRAY)
|
"""
|
||||||
|
Preprocess single image for emotion detection
|
||||||
|
Args:
|
||||||
|
img: Input image (224, 224, 3)
|
||||||
|
Returns:
|
||||||
|
Preprocessed grayscale image (48, 48)
|
||||||
|
"""
|
||||||
|
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
img_gray = cv2.resize(img_gray, (48, 48))
|
img_gray = cv2.resize(img_gray, (48, 48))
|
||||||
img_gray = np.expand_dims(img_gray, axis=0)
|
return img_gray
|
||||||
|
|
||||||
# model.predict causes memory issue when it is called in a for loop
|
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
|
||||||
# emotion_predictions = self.model.predict(img_gray, verbose=0)[0, :]
|
"""
|
||||||
emotion_predictions = self.model(img_gray, training=False).numpy()[0, :]
|
Predict emotion probabilities for single or multiple faces
|
||||||
|
Args:
|
||||||
|
img: Single image as np.ndarray (224, 224, 3) or
|
||||||
|
List of images as List[np.ndarray] or
|
||||||
|
Batch of images as np.ndarray (n, 224, 224, 3)
|
||||||
|
Returns:
|
||||||
|
np.ndarray (n, n_emotions)
|
||||||
|
where n_emotions is the number of emotion categories
|
||||||
|
"""
|
||||||
|
# Preprocessing input image or image list.
|
||||||
|
imgs = self._preprocess_batch_or_single_input(img)
|
||||||
|
|
||||||
return emotion_predictions
|
processed_imgs = np.expand_dims(np.array([self._preprocess_image(img) for img in imgs]), axis=-1)
|
||||||
|
|
||||||
|
# Prediction
|
||||||
|
predictions = self._predict_internal(processed_imgs)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
# stdlib dependencies
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -37,11 +41,23 @@ class GenderClient(Demography):
|
|||||||
self.model = load_model()
|
self.model = load_model()
|
||||||
self.model_name = "Gender"
|
self.model_name = "Gender"
|
||||||
|
|
||||||
def predict(self, img: np.ndarray) -> np.ndarray:
|
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
|
||||||
# model.predict causes memory issue when it is called in a for loop
|
"""
|
||||||
# return self.model.predict(img, verbose=0)[0, :]
|
Predict gender probabilities for single or multiple faces
|
||||||
return self.model(img, training=False).numpy()[0, :]
|
Args:
|
||||||
|
img: Single image as np.ndarray (224, 224, 3) or
|
||||||
|
List of images as List[np.ndarray] or
|
||||||
|
Batch of images as np.ndarray (n, 224, 224, 3)
|
||||||
|
Returns:
|
||||||
|
np.ndarray (n, 2)
|
||||||
|
"""
|
||||||
|
# Preprocessing input image or image list.
|
||||||
|
imgs = self._preprocess_batch_or_single_input(img)
|
||||||
|
|
||||||
|
# Prediction
|
||||||
|
predictions = self._predict_internal(imgs)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
url=WEIGHTS_URL,
|
url=WEIGHTS_URL,
|
||||||
@ -64,7 +80,7 @@ def load_model(
|
|||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
gender_model = Model(inputs=model.input, outputs=base_model_output)
|
gender_model = Model(inputs=model.inputs, outputs=base_model_output)
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
# stdlib dependencies
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -37,10 +40,24 @@ class RaceClient(Demography):
|
|||||||
self.model = load_model()
|
self.model = load_model()
|
||||||
self.model_name = "Race"
|
self.model_name = "Race"
|
||||||
|
|
||||||
def predict(self, img: np.ndarray) -> np.ndarray:
|
def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
|
||||||
# model.predict causes memory issue when it is called in a for loop
|
"""
|
||||||
# return self.model.predict(img, verbose=0)[0, :]
|
Predict race probabilities for single or multiple faces
|
||||||
return self.model(img, training=False).numpy()[0, :]
|
Args:
|
||||||
|
img: Single image as np.ndarray (224, 224, 3) or
|
||||||
|
List of images as List[np.ndarray] or
|
||||||
|
Batch of images as np.ndarray (n, 224, 224, 3)
|
||||||
|
Returns:
|
||||||
|
np.ndarray (n, n_races)
|
||||||
|
where n_races is the number of race categories
|
||||||
|
"""
|
||||||
|
# Preprocessing input image or image list.
|
||||||
|
imgs = self._preprocess_batch_or_single_input(img)
|
||||||
|
|
||||||
|
# Prediction
|
||||||
|
predictions = self._predict_internal(imgs)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
@ -62,7 +79,7 @@ def load_model(
|
|||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
race_model = Model(inputs=model.input, outputs=base_model_output)
|
race_model = Model(inputs=model.inputs, outputs=base_model_output)
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
|
@ -100,6 +100,30 @@ def analyze(
|
|||||||
- 'white': Confidence score for White ethnicity.
|
- 'white': Confidence score for White ethnicity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4:
|
||||||
|
# Received 4-D array, which means image batch.
|
||||||
|
# Check batch dimension and process each image separately.
|
||||||
|
if img_path.shape[0] > 1:
|
||||||
|
batch_resp_obj = []
|
||||||
|
# Execute analysis for each image in the batch.
|
||||||
|
for single_img in img_path:
|
||||||
|
# Call the analyze function for each image in the batch.
|
||||||
|
resp_obj = analyze(
|
||||||
|
img_path=single_img,
|
||||||
|
actions=actions,
|
||||||
|
enforce_detection=enforce_detection,
|
||||||
|
detector_backend=detector_backend,
|
||||||
|
align=align,
|
||||||
|
expand_percentage=expand_percentage,
|
||||||
|
silent=silent,
|
||||||
|
anti_spoofing=anti_spoofing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append the response object to the batch response list.
|
||||||
|
batch_resp_obj.append(resp_obj)
|
||||||
|
return batch_resp_obj
|
||||||
|
|
||||||
|
|
||||||
# if actions is passed as tuple with single item, interestingly it becomes str here
|
# if actions is passed as tuple with single item, interestingly it becomes str here
|
||||||
if isinstance(actions, str):
|
if isinstance(actions, str):
|
||||||
actions = (actions,)
|
actions = (actions,)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# project dependencies
|
# project dependencies
|
||||||
from deepface import DeepFace
|
from deepface import DeepFace
|
||||||
|
from deepface.models.demography import Age, Emotion, Gender, Race
|
||||||
from deepface.commons.logger import Logger
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
@ -16,6 +18,7 @@ def test_standard_analyze():
|
|||||||
demography_objs = DeepFace.analyze(img, silent=True)
|
demography_objs = DeepFace.analyze(img, silent=True)
|
||||||
for demography in demography_objs:
|
for demography in demography_objs:
|
||||||
logger.debug(demography)
|
logger.debug(demography)
|
||||||
|
assert type(demography) == dict
|
||||||
assert demography["age"] > 20 and demography["age"] < 40
|
assert demography["age"] > 20 and demography["age"] < 40
|
||||||
assert demography["dominant_gender"] == "Woman"
|
assert demography["dominant_gender"] == "Woman"
|
||||||
logger.info("✅ test standard analyze done")
|
logger.info("✅ test standard analyze done")
|
||||||
@ -29,6 +32,7 @@ def test_analyze_with_all_actions_as_tuple():
|
|||||||
|
|
||||||
for demography in demography_objs:
|
for demography in demography_objs:
|
||||||
logger.debug(f"Demography: {demography}")
|
logger.debug(f"Demography: {demography}")
|
||||||
|
assert type(demography) == dict
|
||||||
age = demography["age"]
|
age = demography["age"]
|
||||||
gender = demography["dominant_gender"]
|
gender = demography["dominant_gender"]
|
||||||
race = demography["dominant_race"]
|
race = demography["dominant_race"]
|
||||||
@ -53,6 +57,7 @@ def test_analyze_with_all_actions_as_list():
|
|||||||
|
|
||||||
for demography in demography_objs:
|
for demography in demography_objs:
|
||||||
logger.debug(f"Demography: {demography}")
|
logger.debug(f"Demography: {demography}")
|
||||||
|
assert type(demography) == dict
|
||||||
age = demography["age"]
|
age = demography["age"]
|
||||||
gender = demography["dominant_gender"]
|
gender = demography["dominant_gender"]
|
||||||
race = demography["dominant_race"]
|
race = demography["dominant_race"]
|
||||||
@ -74,6 +79,7 @@ def test_analyze_for_some_actions():
|
|||||||
demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True)
|
demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True)
|
||||||
|
|
||||||
for demography in demography_objs:
|
for demography in demography_objs:
|
||||||
|
assert type(demography) == dict
|
||||||
age = demography["age"]
|
age = demography["age"]
|
||||||
gender = demography["dominant_gender"]
|
gender = demography["dominant_gender"]
|
||||||
|
|
||||||
@ -95,6 +101,7 @@ def test_analyze_for_preloaded_image():
|
|||||||
resp_objs = DeepFace.analyze(img, silent=True)
|
resp_objs = DeepFace.analyze(img, silent=True)
|
||||||
for resp_obj in resp_objs:
|
for resp_obj in resp_objs:
|
||||||
logger.debug(resp_obj)
|
logger.debug(resp_obj)
|
||||||
|
assert type(resp_obj) == dict
|
||||||
assert resp_obj["age"] > 20 and resp_obj["age"] < 40
|
assert resp_obj["age"] > 20 and resp_obj["age"] < 40
|
||||||
assert resp_obj["dominant_gender"] == "Woman"
|
assert resp_obj["dominant_gender"] == "Woman"
|
||||||
|
|
||||||
@ -131,7 +138,73 @@ def test_analyze_for_different_detectors():
|
|||||||
]
|
]
|
||||||
|
|
||||||
# validate probabilities
|
# validate probabilities
|
||||||
|
assert type(result) == dict
|
||||||
if result["dominant_gender"] == "Man":
|
if result["dominant_gender"] == "Man":
|
||||||
assert result["gender"]["Man"] > result["gender"]["Woman"]
|
assert result["gender"]["Man"] > result["gender"]["Woman"]
|
||||||
else:
|
else:
|
||||||
assert result["gender"]["Man"] < result["gender"]["Woman"]
|
assert result["gender"]["Man"] < result["gender"]["Woman"]
|
||||||
|
|
||||||
|
def test_analyze_for_batched_image():
|
||||||
|
img = "dataset/img4.jpg"
|
||||||
|
# Copy and combine the same image to create multiple faces
|
||||||
|
img = cv2.imread(img)
|
||||||
|
img = np.stack([img, img])
|
||||||
|
assert len(img.shape) == 4 # Check dimension.
|
||||||
|
assert img.shape[0] == 2 # Check batch size.
|
||||||
|
|
||||||
|
demography_batch = DeepFace.analyze(img, silent=True)
|
||||||
|
# 2 image in batch, so 2 demography objects.
|
||||||
|
assert len(demography_batch) == 2
|
||||||
|
|
||||||
|
for demography_objs in demography_batch:
|
||||||
|
assert len(demography_objs) == 1 # 1 face in each image
|
||||||
|
for demography in demography_objs: # Iterate over faces
|
||||||
|
assert type(demography) == dict # Check type
|
||||||
|
assert demography["age"] > 20 and demography["age"] < 40
|
||||||
|
assert demography["dominant_gender"] == "Woman"
|
||||||
|
logger.info("✅ test analyze for multiple faces done")
|
||||||
|
|
||||||
|
def test_batch_detect_age_for_multiple_faces():
|
||||||
|
# Load test image and resize to model input size
|
||||||
|
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
|
||||||
|
imgs = [img, img]
|
||||||
|
results = Age.ApparentAgeClient().predict(imgs)
|
||||||
|
# Check there are two ages detected
|
||||||
|
assert len(results) == 2
|
||||||
|
# Check two faces ages are the same in integer format(e.g. 23.6 -> 23)
|
||||||
|
# Must use int() to compare because of max float precision issue in different platforms
|
||||||
|
assert np.array_equal(int(results[0]), int(results[1]))
|
||||||
|
logger.info("✅ test batch detect age for multiple faces done")
|
||||||
|
|
||||||
|
def test_batch_detect_emotion_for_multiple_faces():
|
||||||
|
# Load test image and resize to model input size
|
||||||
|
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
|
||||||
|
imgs = [img, img]
|
||||||
|
results = Emotion.EmotionClient().predict(imgs)
|
||||||
|
# Check there are two emotions detected
|
||||||
|
assert len(results) == 2
|
||||||
|
# Check two faces emotions are the same
|
||||||
|
assert np.array_equal(results[0], results[1])
|
||||||
|
logger.info("✅ test batch detect emotion for multiple faces done")
|
||||||
|
|
||||||
|
def test_batch_detect_gender_for_multiple_faces():
|
||||||
|
# Load test image and resize to model input size
|
||||||
|
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
|
||||||
|
imgs = [img, img]
|
||||||
|
results = Gender.GenderClient().predict(imgs)
|
||||||
|
# Check there are two genders detected
|
||||||
|
assert len(results) == 2
|
||||||
|
# Check two genders are the same
|
||||||
|
assert np.array_equal(results[0], results[1])
|
||||||
|
logger.info("✅ test batch detect gender for multiple faces done")
|
||||||
|
|
||||||
|
def test_batch_detect_race_for_multiple_faces():
|
||||||
|
# Load test image and resize to model input size
|
||||||
|
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
|
||||||
|
imgs = [img, img]
|
||||||
|
results = Race.RaceClient().predict(imgs)
|
||||||
|
# Check there are two races detected
|
||||||
|
assert len(results) == 2
|
||||||
|
# Check two races are the same
|
||||||
|
assert np.array_equal(results[0], results[1])
|
||||||
|
logger.info("✅ test batch detect race for multiple faces done")
|
Loading…
x
Reference in New Issue
Block a user