Merge pull request #4 from NatLee/feat/add-multi-face-test

Feat/add multi face test
Nat 2024-12-31 13:54:01 +08:00 committed by GitHub
commit 4cf43be49d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 106 additions and 99 deletions

deepface/models/demography/Age.py (View File)

@@ -41,7 +41,7 @@ class ApparentAgeClient(Demography):
         self.model = load_model()
         self.model_name = "Age"
 
-    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
         """
         Predict apparent age(s) for single or multiple faces
         Args:
@@ -49,8 +49,7 @@ class ApparentAgeClient(Demography):
                 List of images as List[np.ndarray] or
                 Batch of images as np.ndarray (n, 224, 224, 3)
         Returns:
-            Single age as np.float64 or
-            Multiple ages as np.ndarray (n,)
+            np.ndarray (n,)
         """
         # Convert to numpy array if input is list
         if isinstance(img, list):
@@ -65,9 +64,6 @@ class ApparentAgeClient(Demography):
         if len(imgs.shape) == 3:
             # Single image - add batch dimension
             imgs = np.expand_dims(imgs, axis=0)
-            is_single = True
-        else:
-            is_single = False
 
         # Batch prediction
         age_predictions = self.model.predict_on_batch(imgs)
@@ -77,9 +73,6 @@ class ApparentAgeClient(Demography):
             [find_apparent_age(age_prediction) for age_prediction in age_predictions]
         )
 
-        # Return single value for single image
-        if is_single:
-            return apparent_ages[0]
         return apparent_ages
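
Note the behavioral change in the hunks above: with is_single gone, predict always returns a batch-shaped array, so a caller with one face reads result[0] instead of relying on a scalar. A minimal sketch of the contract, using a hypothetical model stand-in and assuming the 101-bin age head that find_apparent_age reduces over:

import numpy as np

def predict_ages(model, img) -> np.ndarray:
    # Normalize a single (224, 224, 3) image or a list of them into a batch
    imgs = np.asarray(img, dtype=np.float32)
    if imgs.ndim == 3:
        imgs = np.expand_dims(imgs, axis=0)  # (224, 224, 3) -> (1, 224, 224, 3)
    probs = model.predict_on_batch(imgs)     # assumed shape: (n, 101)
    # Apparent age as the expected value over the 0..100 age bins,
    # the same reduction find_apparent_age applies per row
    return probs @ np.arange(0, 101)         # always shape (n,)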

deepface/models/demography/Emotion.py (View File)

@@ -58,7 +58,7 @@ class EmotionClient(Demography):
         img_gray = cv2.resize(img_gray, (48, 48))
         return img_gray
 
-    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]:
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
         """
         Predict emotion probabilities for single or multiple faces
         Args:
@@ -66,8 +66,7 @@ class EmotionClient(Demography):
                 List of images as List[np.ndarray] or
                 Batch of images as np.ndarray (n, 224, 224, 3)
         Returns:
-            Single prediction as np.ndarray (n_emotions,) [emotion_probs] or
-            Multiple predictions as np.ndarray (n, n_emotions)
+            np.ndarray (n, n_emotions)
             where n_emotions is the number of emotion categories
         """
         # Convert to numpy array if input is list
@@ -83,9 +82,6 @@ class EmotionClient(Demography):
         if len(imgs.shape) == 3:
             # Single image - add batch dimension
             imgs = np.expand_dims(imgs, axis=0)
-            is_single = True
-        else:
-            is_single = False
 
         # Preprocess each image
         processed_imgs = np.array([self._preprocess_image(img) for img in imgs])
@@ -96,13 +92,9 @@ class EmotionClient(Demography):
         # Batch prediction
         predictions = self.model.predict_on_batch(processed_imgs)
 
-        # Return single prediction for single image
-        if is_single:
-            return predictions[0]
         return predictions
 
 
 def load_model(
     url=WEIGHTS_URL,
 ) -> Sequential:
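
The emotion client is the one model here with an extra preprocessing step: each face is reduced to a 48x48 grayscale crop before the batch is stacked, which is why the diff keeps the per-image _preprocess_image loop. A rough sketch of that stacking, assuming uint8 RGB crops as input:

import cv2
import numpy as np

def preprocess_batch(faces: np.ndarray) -> np.ndarray:
    # (n, 224, 224, 3) RGB crops -> (n, 48, 48) grayscale batch
    processed = []
    for face in faces:
        gray = cv2.cvtColor(face, cv2.COLOR_RGB2GRAY)
        processed.append(cv2.resize(gray, (48, 48)))
    return np.array(processed)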

deepface/models/demography/Gender.py (View File)

@@ -41,7 +41,7 @@ class GenderClient(Demography):
         self.model = load_model()
         self.model_name = "Gender"
 
-    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]:
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
         """
         Predict gender probabilities for single or multiple faces
         Args:
@@ -49,8 +49,7 @@ class GenderClient(Demography):
                 List of images as List[np.ndarray] or
                 Batch of images as np.ndarray (n, 224, 224, 3)
         Returns:
-            Single prediction as np.ndarray (2,) [female_prob, male_prob] or
-            Multiple predictions as np.ndarray (n, 2)
+            np.ndarray (n, 2)
         """
         # Convert to numpy array if input is list
         if isinstance(img, list):
@@ -65,16 +64,10 @@ class GenderClient(Demography):
         if len(imgs.shape) == 3:
             # Single image - add batch dimension
             imgs = np.expand_dims(imgs, axis=0)
-            is_single = True
-        else:
-            is_single = False
 
         # Batch prediction
         predictions = self.model.predict_on_batch(imgs)
 
-        # Return single prediction for single image
-        if is_single:
-            return predictions[0]
         return predictions
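
Because predict now always yields (n, 2), callers such as analyze turn each row into a labeled percentage dict with a plain loop. An illustrative sketch; the label list here is an assumption mirroring Gender.labels:

import numpy as np

labels = ["Woman", "Man"]  # assumption: mirrors Gender.labels

def to_response(predictions: np.ndarray) -> list:
    # Map an (n, 2) probability array to one dict per face
    results = []
    for row in predictions:
        scores = {label: 100 * float(p) for label, p in zip(labels, row)}
        results.append({"gender": scores, "dominant_gender": labels[int(np.argmax(row))]})
    return results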

deepface/models/demography/Race.py (View File)

@@ -40,7 +40,7 @@ class RaceClient(Demography):
         self.model = load_model()
         self.model_name = "Race"
 
-    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]:
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
         """
         Predict race probabilities for single or multiple faces
         Args:
@@ -48,8 +48,7 @@ class RaceClient(Demography):
                 List of images as List[np.ndarray] or
                 Batch of images as np.ndarray (n, 224, 224, 3)
         Returns:
-            Single prediction as np.ndarray (n_races,) [race_probs] or
-            Multiple predictions as np.ndarray (n, n_races)
+            np.ndarray (n, n_races)
             where n_races is the number of race categories
         """
         # Convert to numpy array if input is list
@@ -65,16 +64,10 @@ class RaceClient(Demography):
         if len(imgs.shape) == 3:
             # Single image - add batch dimension
             imgs = np.expand_dims(imgs, axis=0)
-            is_single = True
-        else:
-            is_single = False
 
         # Batch prediction
         predictions = self.model.predict_on_batch(imgs)
 
-        # Return single prediction for single image
-        if is_single:
-            return predictions[0]
         return predictions
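
Unlike the gender head, the race (and emotion) scores are divided by their row sum before being reported as percentages, as analyze does below. A small sketch of that normalization over a batch; the label list is an assumption mirroring Race.labels:

import numpy as np

labels = ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"]  # assumption

def to_percentages(predictions: np.ndarray) -> list:
    # Normalize each (n_races,) row so the reported values sum to 100
    out = []
    for row in predictions:
        total = row.sum()
        out.append({label: 100 * float(p) / total for label, p in zip(labels, row)})
    return out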

deepface/modules/demography.py (View File)

@@ -9,7 +9,7 @@ from tqdm import tqdm
 from deepface.modules import modeling, detection, preprocessing
 from deepface.models.demography import Gender, Race, Emotion
 
-
+# pylint: disable=trailing-whitespace
 def analyze(
     img_path: Union[str, np.ndarray],
     actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
@@ -130,83 +130,107 @@ def analyze(
         anti_spoofing=anti_spoofing,
     )
 
-    for img_obj in img_objs:
-        if anti_spoofing is True and img_obj.get("is_real", True) is False:
-            raise ValueError("Spoof detected in the given image.")
-
-        img_content = img_obj["face"]
-        img_region = img_obj["facial_area"]
-        img_confidence = img_obj["confidence"]
-        if img_content.shape[0] == 0 or img_content.shape[1] == 0:
-            continue
-
-        # rgb to bgr
-        img_content = img_content[:, :, ::-1]
-
-        # resize input image
-        img_content = preprocessing.resize_image(img=img_content, target_size=(224, 224))
-
-        obj = {}
-        # facial attribute analysis
-        pbar = tqdm(
-            range(0, len(actions)),
-            desc="Finding actions",
-            disable=silent if len(actions) > 1 else True,
-        )
-        for index in pbar:
-            action = actions[index]
-            pbar.set_description(f"Action: {action}")
-
-            if action == "emotion":
-                emotion_predictions = modeling.build_model(
-                    task="facial_attribute", model_name="Emotion"
-                ).predict(img_content)
-
-                sum_of_predictions = emotion_predictions.sum()
-
-                obj["emotion"] = {}
-                for i, emotion_label in enumerate(Emotion.labels):
-                    emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
-                    obj["emotion"][emotion_label] = emotion_prediction
-
-                obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
-
-            elif action == "age":
-                apparent_age = modeling.build_model(
-                    task="facial_attribute", model_name="Age"
-                ).predict(img_content)
-                # int cast is for exception - object of type 'float32' is not JSON serializable
-                obj["age"] = int(apparent_age)
-
-            elif action == "gender":
-                gender_predictions = modeling.build_model(
-                    task="facial_attribute", model_name="Gender"
-                ).predict(img_content)
-                obj["gender"] = {}
-                for i, gender_label in enumerate(Gender.labels):
-                    gender_prediction = 100 * gender_predictions[i]
-                    obj["gender"][gender_label] = gender_prediction
-
-                obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
-
-            elif action == "race":
-                race_predictions = modeling.build_model(
-                    task="facial_attribute", model_name="Race"
-                ).predict(img_content)
-                sum_of_predictions = race_predictions.sum()
-
-                obj["race"] = {}
-                for i, race_label in enumerate(Race.labels):
-                    race_prediction = 100 * race_predictions[i] / sum_of_predictions
-                    obj["race"][race_label] = race_prediction
-
-                obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
-
-        # -----------------------------
-        # mention facial areas
-        obj["region"] = img_region
-        # include image confidence
-        obj["face_confidence"] = img_confidence
-
-        resp_objects.append(obj)
+    # Anti-spoofing check
+    if anti_spoofing:
+        for img_obj in img_objs:
+            if img_obj.get("is_real", True) is False:
+                raise ValueError("Spoof detected in the given image.")
+
+    # Prepare the input for the model
+    valid_faces = []
+    face_regions = []
+    face_confidences = []
+
+    for img_obj in img_objs:
+        # Extract the face content
+        img_content = img_obj["face"]
+        # Check if the face content is empty
+        if img_content.shape[0] == 0 or img_content.shape[1] == 0:
+            continue
+        # Convert the image to RGB format from BGR
+        img_content = img_content[:, :, ::-1]
+        # Resize the image to the target size for the model
+        img_content = preprocessing.resize_image(img=img_content, target_size=(224, 224))
+
+        valid_faces.append(img_content)
+        face_regions.append(img_obj["facial_area"])
+        face_confidences.append(img_obj["confidence"])
+
+    # If no valid faces are found, return an empty list
+    if not valid_faces:
+        return []
+
+    # Convert the list of valid faces to a numpy array
+    faces_array = np.array(valid_faces)
+    resp_objects = [{} for _ in range(len(valid_faces))]
+
+    # For each action, predict the corresponding attribute
+    pbar = tqdm(
+        range(0, len(actions)),
+        desc="Finding actions",
+        disable=silent if len(actions) > 1 else True,
+    )
+    for index in pbar:
+        action = actions[index]
+        pbar.set_description(f"Action: {action}")
+
+        if action == "emotion":
+            # Build the emotion model
+            model = modeling.build_model(task="facial_attribute", model_name="Emotion")
+            emotion_predictions = model.predict(faces_array)
+
+            for idx, predictions in enumerate(emotion_predictions):
+                sum_of_predictions = predictions.sum()
+                resp_objects[idx]["emotion"] = {}
+
+                for i, emotion_label in enumerate(Emotion.labels):
+                    emotion_prediction = 100 * predictions[i] / sum_of_predictions
+                    resp_objects[idx]["emotion"][emotion_label] = emotion_prediction
+
+                resp_objects[idx]["dominant_emotion"] = Emotion.labels[np.argmax(predictions)]
+
+        elif action == "age":
+            # Build the age model
+            model = modeling.build_model(task="facial_attribute", model_name="Age")
+            age_predictions = model.predict(faces_array)
+
+            for idx, age in enumerate(age_predictions):
+                resp_objects[idx]["age"] = int(age)
+
+        elif action == "gender":
+            # Build the gender model
+            model = modeling.build_model(task="facial_attribute", model_name="Gender")
+            gender_predictions = model.predict(faces_array)
+
+            for idx, predictions in enumerate(gender_predictions):
+                resp_objects[idx]["gender"] = {}
+                for i, gender_label in enumerate(Gender.labels):
+                    gender_prediction = 100 * predictions[i]
+                    resp_objects[idx]["gender"][gender_label] = gender_prediction
+
+                resp_objects[idx]["dominant_gender"] = Gender.labels[np.argmax(predictions)]
+
+        elif action == "race":
+            # Build the race model
+            model = modeling.build_model(task="facial_attribute", model_name="Race")
+            race_predictions = model.predict(faces_array)
+
+            for idx, predictions in enumerate(race_predictions):
+                sum_of_predictions = predictions.sum()
+                resp_objects[idx]["race"] = {}
+
+                for i, race_label in enumerate(Race.labels):
+                    race_prediction = 100 * predictions[i] / sum_of_predictions
+                    resp_objects[idx]["race"][race_label] = race_prediction
+
+                resp_objects[idx]["dominant_race"] = Race.labels[np.argmax(predictions)]
+
+    # Add the face region and confidence to the response objects
+    for idx, resp_obj in enumerate(resp_objects):
+        resp_obj["region"] = face_regions[idx]
+        resp_obj["face_confidence"] = face_confidences[idx]
 
     return resp_objects
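
The net effect of this rewrite is one model call per action over the whole face batch instead of one per detected face, while the response schema is unchanged: a list with one dict per face. A hedged usage sketch against the public API (the image path is illustrative):

import cv2
from deepface import DeepFace

img = cv2.imread("group_photo.jpg")  # hypothetical image containing several faces
results = DeepFace.analyze(img, actions=("age", "gender"), silent=True)

# One dict per detected face, each carrying region, confidence, and attributes
for face in results:
    print(face["region"], face["age"], face["dominant_gender"])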

tests/test_analyze.py (View File)

@@ -135,3 +135,15 @@ def test_analyze_for_different_detectors():
             assert result["gender"]["Man"] > result["gender"]["Woman"]
         else:
             assert result["gender"]["Man"] < result["gender"]["Woman"]
+
+
+def test_analyze_for_multiple_faces():
+    img = "dataset/img4.jpg"
+    # Copy and combine the same image to create multiple faces
+    img = cv2.imread(img)
+    img = cv2.hconcat([img, img])
+    demography_objs = DeepFace.analyze(img, silent=True)
+    for demography in demography_objs:
+        logger.debug(demography)
+        assert demography["age"] > 20 and demography["age"] < 40
+        assert demography["dominant_gender"] == "Woman"
+    logger.info("✅ test analyze for multiple faces done")