Merge branch 'master' of https://github.com/dakotah-jones/deepface into patch-1

dakotah-jones 2025-02-18 16:36:10 -05:00
commit 27d82ff00b
16 changed files with 568 additions and 151 deletions

View File

@@ -405,23 +405,27 @@ If you do like this work, then you can support it financially on [Patreon](https
   <img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/patreon.png" width="30%" height="30%">
 </a>
+<a href="https://github.com/sponsors/serengil">
+  <img src="https://raw.githubusercontent.com/serengil/deepface/refs/heads/master/icon/github_sponsor_button.png" width="37%" height="37%">
+</a>
 <a href="https://buymeacoffee.com/serengil">
   <img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/bmc-button.png" width="25%" height="25%">
 </a>
+<!--
 Additionally, you can help us reach a wider audience by upvoting our posts on Hacker News and Product Hunt.
 <div style="display: flex; align-items: center; gap: 10px;">
 <!-- Hacker News Badge -->
 <a href="https://news.ycombinator.com/item?id=42584896">
   <img src="https://hackerbadge.vercel.app/api?id=42584896&type=orange" style="width: 250px; height: 54px;" width="250" alt="Featured on Hacker News">
 </a>
 <!-- Product Hunt Badge -->
 <a href="https://www.producthunt.com/posts/deepface?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-deepface" target="_blank">
   <img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=753599&theme=light" alt="DeepFace - A Lightweight Deep Face Recognition Library for Python | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" />
 </a>
 </div>
+-->
 ## Citation

View File

@@ -2,7 +2,7 @@
 import os
 import warnings
 import logging
-from typing import Any, Dict, IO, List, Union, Optional
+from typing import Any, Dict, IO, List, Union, Optional, Sequence
 # this has to be set before importing tensorflow
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
@@ -174,7 +174,7 @@ def analyze(
     expand_percentage: int = 0,
     silent: bool = False,
     anti_spoofing: bool = False,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Analyze facial attributes such as age, gender, emotion, and race in the provided image.
     Args:
@@ -206,7 +206,10 @@ def analyze(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
+        (List[List[Dict[str, Any]]]): A list of lists of analysis results if a batched image
+            is received, structured as explained below.
+        (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
             the analysis results for a detected face. Each dictionary in the list contains the
             following keys:
@@ -373,7 +376,7 @@ def find(
 def represent(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -382,15 +385,18 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.
     Args:
-        img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+            The exact path to the image, a numpy array
             in BGR format, a file object that supports at least `.read` and is opened in binary
             mode, or a base64 encoded image. If the source image contains multiple faces,
-            the result will include information for each detected face.
+            the result will include information for each detected face. If a sequence is provided,
+            each element should be a string or numpy array representing an image, and the function
+            will process the images as a batch.
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@@ -417,8 +423,9 @@ def represent(
         max_faces (int): Set a limit on the number of faces to be processed (default is None).
     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
+            The result becomes a list of lists of dictionaries if a batch input is passed.
+            Each dictionary contains the following fields:
         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
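For orientation, a minimal usage sketch of the batched `represent` call described above (image paths borrowed from the test images used elsewhere in this commit; exact embedding contents depend on the chosen model):

from deepface import DeepFace

# single image: List[Dict[str, Any]], one dict per detected face
single = DeepFace.represent(img_path="dataset/img1.jpg", model_name="VGG-Face")

# batch input: List[List[Dict[str, Any]]], one inner list per input image
batch = DeepFace.represent(
    img_path=["dataset/img1.jpg", "dataset/couple.jpg"],
    model_name="VGG-Face",
)
assert isinstance(single[0], dict)
assert len(batch) == 2 and isinstance(batch[0][0], dict)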

View File

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, List
 from abc import ABC, abstractmethod
 import numpy as np
 from deepface.commons import package_utils
@@ -18,5 +18,51 @@ class Demography(ABC):
     model_name: str
     @abstractmethod
-    def predict(self, img: np.ndarray) -> Union[np.ndarray, np.float64]:
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.float64]:
         pass
+    def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
+        """
+        Predict for a single image or a batch of images.
+        This method uses the legacy single-image path when it receives one image
+        and switches to batch prediction when it receives a batch.
+        Args:
+            img_batch:
+                Batch of images as np.ndarray (n, x, y, c)
+                with n >= 1, x = image width, y = image height, c = channel
+                Or a single image as np.ndarray (1, x, y, c)
+                with x = image width, y = image height and c = channel
+                The channel dimension will be 1 if the input is grayscale (emotion model).
+        """
+        if not self.model_name:  # Check if called from a derived class
+            raise NotImplementedError("no model selected")
+        assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
+        if img_batch.shape[0] == 1:  # Single image
+            # Predict with the legacy method.
+            return self.model(img_batch, training=False).numpy()[0, :]
+        # Batch of images: predict with batch prediction
+        return self.model.predict_on_batch(img_batch)
+    def _preprocess_batch_or_single_input(
+        self, img: Union[np.ndarray, List[np.ndarray]]
+    ) -> np.ndarray:
+        """
+        Preprocess a single image or a batch of images and return a 4-D numpy array.
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            Four-dimensional numpy array (n, 224, 224, 3)
+        """
+        image_batch = np.array(img)
+        # Check input dimension: if a single image, add the batch dimension
+        if len(image_batch.shape) == 3:
+            image_batch = np.expand_dims(image_batch, axis=0)
+        return image_batch
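To illustrate how a demography client is expected to combine the two helpers above, here is a minimal sketch of a hypothetical subclass; it mirrors the Gender and Race clients further down in this commit, and the model-building call is a placeholder, not a real deepface function:

from typing import List, Union
import numpy as np
from deepface.models.Demography import Demography

class SomeAttributeClient(Demography):  # hypothetical client, for illustration only
    def __init__(self):
        self.model = build_some_keras_model()  # placeholder for the client's own load_model()
        self.model_name = "SomeAttribute"

    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
        # 1) normalize the input into a 4-D (n, 224, 224, 3) batch
        imgs = self._preprocess_batch_or_single_input(img)
        # 2) legacy single-image path or predict_on_batch, depending on the batch size
        return self._predict_internal(imgs)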

View File

@@ -18,7 +18,7 @@ class FacialRecognition(ABC):
     input_shape: Tuple[int, int]
     output_shape: int
-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         if not isinstance(self.model, Model):
             raise ValueError(
                 "You must overwrite forward method if it is not a keras model,"
@@ -26,4 +26,9 @@ class FacialRecognition(ABC):
             )
         # model.predict causes memory issue when it is called in a for loop
         # embedding = model.predict(img, verbose=0)[0].tolist()
-        return self.model(img, training=False).numpy()[0].tolist()
+        if len(img.shape) == 3:  # add a batch dimension for a single (x, y, c) image
+            img = np.expand_dims(img, axis=0)
+        embeddings = self.model(img, training=False).numpy()
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()

View File

@@ -1,3 +1,7 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np
@@ -9,7 +13,6 @@ from deepface.commons.logger import Logger
 logger = Logger()
-# ----------------------------------------
 # dependency configurations
 tf_version = package_utils.get_tf_major_version()
@@ -21,12 +24,11 @@ else:
     from tensorflow.keras.models import Model, Sequential
     from tensorflow.keras.layers import Convolution2D, Flatten, Activation
-# ----------------------------------------
 WEIGHTS_URL = (
     "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
 )
 # pylint: disable=too-few-public-methods
 class ApparentAgeClient(Demography):
     """
@@ -37,11 +39,28 @@ class ApparentAgeClient(Demography):
         self.model = load_model()
         self.model_name = "Age"
-    def predict(self, img: np.ndarray) -> np.float64:
-        # model.predict causes memory issue when it is called in a for loop
-        # age_predictions = self.model.predict(img, verbose=0)[0, :]
-        age_predictions = self.model(img, training=False).numpy()[0, :]
-        return find_apparent_age(age_predictions)
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
+        """
+        Predict apparent age(s) for a single face or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.float64 (apparent age) if a single image is given,
+            np.ndarray (n,) of apparent ages if batched images are given.
+        """
+        # Preprocess input image or image list into a 4-D batch.
+        imgs = self._preprocess_batch_or_single_input(img)
+        # Prediction from 3-channel images
+        age_predictions = self._predict_internal(imgs)
+        # Calculate apparent ages
+        if len(age_predictions.shape) == 1:  # Single prediction list
+            return find_apparent_age(age_predictions)
+        return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])
 def load_model(
@@ -65,7 +84,7 @@ def load_model(
     # --------------------------
-    age_model = Model(inputs=model.input, outputs=base_model_output)
+    age_model = Model(inputs=model.inputs, outputs=base_model_output)
     # --------------------------
@@ -83,10 +102,14 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
     """
    Find apparent age prediction from a given probas of ages
     Args:
-        age_predictions (?)
+        age_predictions (age_classes,)
     Returns:
        apparent_age (float)
     """
+    assert (
+        len(age_predictions.shape) == 1
+    ), f"Input should be a list of predictions, \
+        not batched. Got shape: {age_predictions.shape}"
     output_indexes = np.arange(0, 101)
     apparent_age = np.sum(age_predictions * output_indexes)
     return apparent_age
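For intuition, the apparent age computed above is just the expected value of the 101-class age distribution; a small worked sketch with toy probabilities (not real model output):

import numpy as np

# toy distribution: all probability mass split between ages 30 and 40
age_predictions = np.zeros(101)
age_predictions[30] = 0.5
age_predictions[40] = 0.5

output_indexes = np.arange(0, 101)
apparent_age = np.sum(age_predictions * output_indexes)
print(apparent_age)  # 35.0, the expectation of the distribution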

View File

@@ -1,3 +1,6 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np
 import cv2
@@ -43,16 +46,38 @@ class EmotionClient(Demography):
         self.model = load_model()
         self.model_name = "Emotion"
-    def predict(self, img: np.ndarray) -> np.ndarray:
-        img_gray = cv2.cvtColor(img[0], cv2.COLOR_BGR2GRAY)
-        img_gray = cv2.resize(img_gray, (48, 48))
-        img_gray = np.expand_dims(img_gray, axis=0)
-        # model.predict causes memory issue when it is called in a for loop
-        # emotion_predictions = self.model.predict(img_gray, verbose=0)[0, :]
-        emotion_predictions = self.model(img_gray, training=False).numpy()[0, :]
-        return emotion_predictions
+    def _preprocess_image(self, img: np.ndarray) -> np.ndarray:
+        """
+        Preprocess a single image for emotion detection
+        Args:
+            img: Input image (224, 224, 3)
+        Returns:
+            Preprocessed grayscale image (48, 48)
+        """
+        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        img_gray = cv2.resize(img_gray, (48, 48))
+        return img_gray
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+        """
+        Predict emotion probabilities for a single face or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.ndarray (n_emotions,) if a single image is given,
+            np.ndarray (n, n_emotions) if batched images are given,
+            where n_emotions is the number of emotion categories
+        """
+        # Preprocess input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)
+        # Convert each face to 48x48 grayscale and add the channel dimension.
+        processed_imgs = np.expand_dims(
+            np.array([self._preprocess_image(img) for img in imgs]), axis=-1
+        )
+        # Prediction
+        predictions = self._predict_internal(processed_imgs)
+        return predictions
 def load_model(

View File

@@ -1,3 +1,7 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np
@@ -37,11 +41,23 @@ class GenderClient(Demography):
         self.model = load_model()
         self.model_name = "Gender"
-    def predict(self, img: np.ndarray) -> np.ndarray:
-        # model.predict causes memory issue when it is called in a for loop
-        # return self.model.predict(img, verbose=0)[0, :]
-        return self.model(img, training=False).numpy()[0, :]
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+        """
+        Predict gender probabilities for a single face or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.ndarray (2,) if a single image is given,
+            np.ndarray (n, 2) if batched images are given
+        """
+        # Preprocess input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)
+        # Prediction
+        predictions = self._predict_internal(imgs)
+        return predictions
 def load_model(
     url=WEIGHTS_URL,
@@ -64,7 +80,7 @@ def load_model(
     # --------------------------
-    gender_model = Model(inputs=model.input, outputs=base_model_output)
+    gender_model = Model(inputs=model.inputs, outputs=base_model_output)
     # --------------------------

View File

@@ -1,3 +1,6 @@
+# stdlib dependencies
+from typing import List, Union
+
 # 3rd party dependencies
 import numpy as np
@@ -37,10 +40,24 @@ class RaceClient(Demography):
         self.model = load_model()
         self.model_name = "Race"
-    def predict(self, img: np.ndarray) -> np.ndarray:
-        # model.predict causes memory issue when it is called in a for loop
-        # return self.model.predict(img, verbose=0)[0, :]
-        return self.model(img, training=False).numpy()[0, :]
+    def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+        """
+        Predict race probabilities for a single face or multiple faces
+        Args:
+            img: Single image as np.ndarray (224, 224, 3) or
+                 List of images as List[np.ndarray] or
+                 Batch of images as np.ndarray (n, 224, 224, 3)
+        Returns:
+            np.ndarray (n_races,) if a single image is given,
+            np.ndarray (n, n_races) if batched images are given,
+            where n_races is the number of race categories
+        """
+        # Preprocess input image or image list.
+        imgs = self._preprocess_batch_or_single_input(img)
+        # Prediction
+        predictions = self._predict_internal(imgs)
+        return predictions
 def load_model(
@@ -62,7 +79,7 @@ def load_model(
     # --------------------------
-    race_model = Model(inputs=model.input, outputs=base_model_output)
+    race_model = Model(inputs=model.inputs, outputs=base_model_output)
     # --------------------------

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union
 # 3rd party dependencies
 import numpy as np
@@ -26,24 +26,22 @@ class DlibClient(FacialRecognition):
         self.input_shape = (150, 150)
         self.output_shape = 128
-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with Dlib model.
         This model necessitates the override of the forward method
         because it is not a keras model.
         Args:
-            img (np.ndarray): pre-loaded image in BGR
+            img (np.ndarray): pre-loaded image(s) in BGR
         Returns
-            embeddings (list): multi-dimensional vector
+            embeddings (list of lists or list of floats): multi-dimensional vectors
         """
-        # return self.model.predict(img)[0].tolist()
-        # extract_faces returns 4 dimensional images
-        if len(img.shape) == 4:
-            img = img[0]
+        # Handle single image case
+        if len(img.shape) == 3:
+            img = np.expand_dims(img, axis=0)
         # bgr to rgb
-        img = img[:, :, ::-1]  # bgr to rgb
+        img = img[:, :, :, ::-1]  # bgr to rgb
         # img is in scale of [0, 1] but expected [0, 255]
         if img.max() <= 1:
@@ -51,10 +49,11 @@ class DlibClient(FacialRecognition):
             img = img.astype(np.uint8)
-        img_representation = self.model.model.compute_face_descriptor(img)
-        img_representation = np.array(img_representation)
-        img_representation = np.expand_dims(img_representation, axis=0)
-        return img_representation[0].tolist()
+        embeddings = self.model.model.compute_face_descriptor(img)
+        embeddings = [np.array(embedding).tolist() for embedding in embeddings]
+        if len(embeddings) == 1:
+            return embeddings[0]
+        return embeddings
 class DlibResNet:

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, List
+from typing import Any, List, Union
 # 3rd party dependencies
 import numpy as np
@@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
         self.input_shape = (112, 112)
         self.output_shape = 128
-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with SFace model
         This model necessitates the override of the forward method
@@ -37,14 +37,17 @@ class SFaceClient(FacialRecognition):
         Returns
             embeddings (list): multi-dimensional vector
         """
-        # return self.model.predict(img)[0].tolist()
-        # revert the image to original format and preprocess using the model
-        input_blob = (img[0] * 255).astype(np.uint8)
-        embeddings = self.model.model.feature(input_blob)
-        return embeddings[0].tolist()
+        input_blob = (img * 255).astype(np.uint8)
+        embeddings = []
+        for i in range(input_blob.shape[0]):
+            embedding = self.model.model.feature(input_blob[i])
+            embeddings.append(embedding)
+        embeddings = np.concatenate(embeddings, axis=0)
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()
 def load_model(

View File

@@ -57,8 +57,7 @@ class VggFaceClient(FacialRecognition):
     def forward(self, img: np.ndarray) -> List[float]:
         """
         Generates embeddings using the VGG-Face model.
-        This method incorporates an additional normalization layer,
-        necessitating the override of the forward method.
+        This method incorporates an additional normalization layer.
         Args:
             img (np.ndarray): pre-loaded image in BGR
@@ -70,8 +69,14 @@ class VggFaceClient(FacialRecognition):
         # having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966)
         # instead we are now calculating it with traditional way not with keras backend
-        embedding = self.model(img, training=False).numpy()[0].tolist()
-        embedding = verification.l2_normalize(embedding)
+        embedding = super().forward(img)
+        if (
+            isinstance(embedding, list) and
+            isinstance(embedding[0], list)
+        ):
+            embedding = verification.l2_normalize(embedding, axis=1)
+        else:
+            embedding = verification.l2_normalize(embedding)
         return embedding.tolist()
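For reference, a small numpy sketch of the two normalization branches above; this is plain numpy written for illustration, not the exact signature of `verification.l2_normalize`:

import numpy as np

def l2_normalize(x: np.ndarray, axis=None, eps: float = 1e-10) -> np.ndarray:
    # divide by the L2 norm; axis=1 normalizes each row (one embedding per face)
    norm = np.linalg.norm(x, axis=axis, keepdims=axis is not None)
    return x / (norm + eps)

single = np.array([3.0, 4.0])                # one flat embedding
batch = np.array([[3.0, 4.0], [0.0, 2.0]])   # one embedding per detected face

print(l2_normalize(single))         # [0.6 0.8]
print(l2_normalize(batch, axis=1))  # each row now has unit length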

View File

@@ -100,6 +100,29 @@ def analyze(
             - 'white': Confidence score for White ethnicity.
     """
+    if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4:
+        # Received a 4-D array, which means an image batch.
+        # Check the batch dimension and process each image separately.
+        if img_path.shape[0] > 1:
+            batch_resp_obj = []
+            # Execute analysis for each image in the batch.
+            for single_img in img_path:
+                # Call the analyze function for each image in the batch.
+                resp_obj = analyze(
+                    img_path=single_img,
+                    actions=actions,
+                    enforce_detection=enforce_detection,
+                    detector_backend=detector_backend,
+                    align=align,
+                    expand_percentage=expand_percentage,
+                    silent=silent,
+                    anti_spoofing=anti_spoofing,
+                )
+                # Append the response object to the batch response list.
+                batch_resp_obj.append(resp_obj)
+            return batch_resp_obj
     # if actions is passed as tuple with single item, interestingly it becomes str here
     if isinstance(actions, str):
         actions = (actions,)
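A minimal usage sketch of the batched path above, mirroring the new test added later in this commit (same-sized images stacked into one 4-D array):

import cv2
import numpy as np
from deepface import DeepFace

# stack two same-sized images into a (2, 500, 500, 3) batch
img1 = cv2.resize(cv2.imread("dataset/img4.jpg"), (500, 500))
img2 = cv2.resize(cv2.imread("dataset/couple.jpg"), (500, 500))
batch = np.stack([img1, img2])

# one list of per-face dicts per input image
results = DeepFace.analyze(img_path=batch, actions=("age", "gender"), silent=True)
assert len(results) == 2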

View File

@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union, Optional, Sequence, IO
 # 3rd party dependencies
 import numpy as np
@@ -11,7 +11,7 @@ from deepface.models.FacialRecognition import FacialRecognition
 def represent(
-    img_path: Union[str, np.ndarray],
+    img_path: Union[str, IO[bytes], np.ndarray, Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -20,14 +20,16 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.
     Args:
-        img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
-            or a base64 encoded image. If the source image contains multiple faces, the result will
-            include information for each detected face.
+        img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
+            The exact path to the image, a numpy array in BGR format,
+            a base64 encoded image, or a sequence of these.
+            If the source image contains multiple faces,
+            the result will include information for each detected face.
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@@ -51,8 +53,9 @@ def represent(
         max_faces (int): Set a limit on the number of faces to be processed (default is None).
     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
+            The result becomes a list of lists of dictionaries if a batch input is passed.
+            Each dictionary contains the following fields:
         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
@@ -70,80 +73,105 @@ def represent(
         task="facial_recognition", model_name=model_name
     )
-    # ---------------------------------
-    # we have run pre-process in verification. so, this can be skipped if it is coming from verify.
-    target_size = model.input_shape
-    if detector_backend != "skip":
-        # Images are returned in RGB format.
-        img_objs = detection.extract_faces(
-            img_path=img_path,
-            detector_backend=detector_backend,
-            grayscale=False,
-            enforce_detection=enforce_detection,
-            align=align,
-            expand_percentage=expand_percentage,
-            anti_spoofing=anti_spoofing,
-            max_faces=max_faces,
-        )
-    else:  # skip
-        # Try load. If load error, will raise exception internal
-        img, _ = image_utils.load_image(img_path)
-        if len(img.shape) != 3:
-            raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
-        # Convert to RGB format to keep compatability with `extract_faces`.
-        img = img[:, :, ::-1]
-        # make dummy region and confidence to keep compatibility with `extract_faces`
-        img_objs = [
-            {
-                "face": img,
-                "facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
-                "confidence": 0,
-            }
-        ]
-    # ---------------------------------
-    if max_faces is not None and max_faces < len(img_objs):
-        # sort as largest facial areas come first
-        img_objs = sorted(
-            img_objs,
-            key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
-            reverse=True,
-        )
-        # discard rest of the items
-        img_objs = img_objs[0:max_faces]
-    for img_obj in img_objs:
-        if anti_spoofing is True and img_obj.get("is_real", True) is False:
-            raise ValueError("Spoof detected in the given image.")
-        img = img_obj["face"]
-        # rgb to bgr
-        img = img[:, :, ::-1]
-        region = img_obj["facial_area"]
-        confidence = img_obj["confidence"]
-        # resize to expected shape of ml model
-        img = preprocessing.resize_image(
-            img=img,
-            # thanks to DeepId (!)
-            target_size=(target_size[1], target_size[0]),
-        )
-        # custom normalization
-        img = preprocessing.normalize_input(img=img, normalization=normalization)
-        embedding = model.forward(img)
-        resp_objs.append(
-            {
-                "embedding": embedding,
-                "facial_area": region,
-                "face_confidence": confidence,
-            }
-        )
-    return resp_objs
+    # Handle list of image paths or 4D numpy array
+    if isinstance(img_path, list):
+        images = img_path
+    elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
+        images = [img_path[i] for i in range(img_path.shape[0])]
+    else:
+        images = [img_path]
+    batch_images, batch_regions, batch_confidences, batch_indexes = [], [], [], []
+    for idx, single_img_path in enumerate(images):
+        # we have run pre-process in verification. so, skip if it is coming from verify.
+        target_size = model.input_shape
+        if detector_backend != "skip":
+            # Images are returned in RGB format.
+            img_objs = detection.extract_faces(
+                img_path=single_img_path,
+                detector_backend=detector_backend,
+                grayscale=False,
+                enforce_detection=enforce_detection,
+                align=align,
+                expand_percentage=expand_percentage,
+                anti_spoofing=anti_spoofing,
+                max_faces=max_faces,
+            )
+        else:  # skip
+            # Try load. If load error, will raise exception internal
+            img, _ = image_utils.load_image(single_img_path)
+            if len(img.shape) != 3:
+                raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
+            # Convert to RGB format to keep compatability with `extract_faces`.
+            img = img[:, :, ::-1]
+            # make dummy region and confidence to keep compatibility with `extract_faces`
+            img_objs = [
+                {
+                    "face": img,
+                    "facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
+                    "confidence": 0,
+                }
+            ]
+        # ---------------------------------
+        if max_faces is not None and max_faces < len(img_objs):
+            # sort as largest facial areas come first
+            img_objs = sorted(
+                img_objs,
+                key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
+                reverse=True,
+            )
+            # discard rest of the items
+            img_objs = img_objs[0:max_faces]
+        for img_obj in img_objs:
+            if anti_spoofing is True and img_obj.get("is_real", True) is False:
+                raise ValueError("Spoof detected in the given image.")
+            img = img_obj["face"]
+            # rgb to bgr
+            img = img[:, :, ::-1]
+            region = img_obj["facial_area"]
+            confidence = img_obj["confidence"]
+            # resize to expected shape of ml model
+            img = preprocessing.resize_image(
+                img=img,
+                # thanks to DeepId (!)
+                target_size=(target_size[1], target_size[0]),
+            )
+            # custom normalization
+            img = preprocessing.normalize_input(img=img, normalization=normalization)
+            batch_images.append(img)
+            batch_regions.append(region)
+            batch_confidences.append(confidence)
+            batch_indexes.append(idx)
+    # Convert list of images to a numpy array for batch processing
+    batch_images = np.concatenate(batch_images, axis=0)
+    # Forward pass through the model for the entire batch
+    embeddings = model.forward(batch_images)
+    for idx in range(0, len(images)):
+        resp_obj = []
+        for idy, batch_index in enumerate(batch_indexes):
+            if idx == batch_index:
+                resp_obj.append(
+                    {
+                        "embedding": embeddings if len(batch_images) == 1 else embeddings[idy],
+                        "facial_area": batch_regions[idy],
+                        "face_confidence": batch_confidences[idy],
+                    }
+                )
+        resp_objs.append(resp_obj)
+    return resp_objs[0] if len(images) == 1 else resp_objs

Binary file not shown (new image added, 7.9 KiB).

View File

@@ -1,8 +1,10 @@
 # 3rd party dependencies
 import cv2
+import numpy as np
 # project dependencies
 from deepface import DeepFace
+from deepface.models.demography import Age, Emotion, Gender, Race
 from deepface.commons.logger import Logger
 logger = Logger()
@@ -16,6 +18,7 @@ def test_standard_analyze():
     demography_objs = DeepFace.analyze(img, silent=True)
     for demography in demography_objs:
         logger.debug(demography)
+        assert type(demography) == dict
         assert demography["age"] > 20 and demography["age"] < 40
         assert demography["dominant_gender"] == "Woman"
     logger.info("✅ test standard analyze done")
@@ -29,6 +32,7 @@ def test_analyze_with_all_actions_as_tuple():
     for demography in demography_objs:
         logger.debug(f"Demography: {demography}")
+        assert type(demography) == dict
         age = demography["age"]
         gender = demography["dominant_gender"]
         race = demography["dominant_race"]
@@ -53,6 +57,7 @@ def test_analyze_with_all_actions_as_list():
     for demography in demography_objs:
         logger.debug(f"Demography: {demography}")
+        assert type(demography) == dict
         age = demography["age"]
         gender = demography["dominant_gender"]
         race = demography["dominant_race"]
@@ -74,6 +79,7 @@ def test_analyze_for_some_actions():
     demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True)
     for demography in demography_objs:
+        assert type(demography) == dict
         age = demography["age"]
         gender = demography["dominant_gender"]
@@ -95,6 +101,7 @@ def test_analyze_for_preloaded_image():
     resp_objs = DeepFace.analyze(img, silent=True)
     for resp_obj in resp_objs:
         logger.debug(resp_obj)
+        assert type(resp_obj) == dict
         assert resp_obj["age"] > 20 and resp_obj["age"] < 40
         assert resp_obj["dominant_gender"] == "Woman"
@@ -131,7 +138,89 @@ def test_analyze_for_different_detectors():
     ]
     # validate probabilities
+    assert type(result) == dict
     if result["dominant_gender"] == "Man":
         assert result["gender"]["Man"] > result["gender"]["Woman"]
     else:
         assert result["gender"]["Man"] < result["gender"]["Woman"]
+def test_analyze_for_numpy_batched_image():
+    img1_path = "dataset/img4.jpg"
+    img2_path = "dataset/couple.jpg"
+    # Combine two different images to create a batch with multiple faces
+    img1 = cv2.imread(img1_path)
+    img2 = cv2.imread(img2_path)
+    expected_num_faces = [1, 2]
+    img1 = cv2.resize(img1, (500, 500))
+    img2 = cv2.resize(img2, (500, 500))
+    img = np.stack([img1, img2])
+    assert len(img.shape) == 4  # Check dimension.
+    assert img.shape[0] == 2  # Check batch size.
+    demography_batch = DeepFace.analyze(img, silent=True)
+    # 2 images in the batch, so 2 demography objects.
+    assert len(demography_batch) == 2
+    for i, demography_objs in enumerate(demography_batch):
+        assert len(demography_objs) == expected_num_faces[i]
+        for demography in demography_objs:  # Iterate over faces
+            assert isinstance(demography, dict)  # Check type
+            assert demography["age"] > 20 and demography["age"] < 40
+            assert demography["dominant_gender"] in ["Woman", "Man"]
+    logger.info("✅ test analyze for multiple faces done")
+def test_batch_detect_age_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Age.ApparentAgeClient().predict(imgs)
+    # Check there are two ages detected
+    assert len(results) == 2
+    # Check the two ages are the same in integer form, e.g. 23.6 -> 23
+    # Must use int() to compare because of float precision differences across platforms
+    assert np.array_equal(int(results[0]), int(results[1]))
+    logger.info("✅ test batch detect age for multiple faces done")
+def test_batch_detect_emotion_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Emotion.EmotionClient().predict(imgs)
+    # Check there are two emotions detected
+    assert len(results) == 2
+    # Check the two faces' emotions are the same
+    assert np.array_equal(results[0], results[1])
+    logger.info("✅ test batch detect emotion for multiple faces done")
+def test_batch_detect_gender_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Gender.GenderClient().predict(imgs)
+    # Check there are two genders detected
+    assert len(results) == 2
+    # Check the two genders are the same
+    assert np.array_equal(results[0], results[1])
+    logger.info("✅ test batch detect gender for multiple faces done")
+def test_batch_detect_race_for_multiple_faces():
+    # Load test image and resize to model input size
+    img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
+    imgs = [img, img]
+    results = Race.RaceClient().predict(imgs)
+    # Check there are two races detected
+    assert len(results) == 2
+    # Check the two races are the same
+    assert np.array_equal(results[0], results[1])
+    logger.info("✅ test batch detect race for multiple faces done")

View File

@@ -2,6 +2,8 @@
 import io
 import cv2
-import pytest
+import numpy as np
+import pytest
 # project dependencies
 from deepface import DeepFace
@@ -13,7 +15,12 @@ logger = Logger()
 def test_standard_represent():
     img_path = "dataset/img1.jpg"
     embedding_objs = DeepFace.represent(img_path)
+    # type should be list of dict
+    assert isinstance(embedding_objs, list)
     for embedding_obj in embedding_objs:
+        assert isinstance(embedding_obj, dict)
         embedding = embedding_obj["embedding"]
         logger.debug(f"Function returned {len(embedding)} dimensional vector")
         assert len(embedding) == 4096
@@ -23,18 +30,18 @@ def test_standard_represent():
 def test_standard_represent_with_io_object():
     img_path = "dataset/img1.jpg"
     default_embedding_objs = DeepFace.represent(img_path)
-    io_embedding_objs = DeepFace.represent(open(img_path, 'rb'))
+    io_embedding_objs = DeepFace.represent(open(img_path, "rb"))
     assert default_embedding_objs == io_embedding_objs
     # Confirm non-seekable io objects are handled properly
-    io_obj = io.BytesIO(open(img_path, 'rb').read())
+    io_obj = io.BytesIO(open(img_path, "rb").read())
     io_obj.seek = None
     no_seek_io_embedding_objs = DeepFace.represent(io_obj)
     assert default_embedding_objs == no_seek_io_embedding_objs
     # Confirm non-image io objects raise exceptions
-    with pytest.raises(ValueError, match='Failed to decode image'):
-        DeepFace.represent(io.BytesIO(open(r'../requirements.txt', 'rb').read()))
+    with pytest.raises(ValueError, match="Failed to decode image"):
+        DeepFace.represent(io.BytesIO(open(r"../requirements.txt", "rb").read()))
     logger.info("✅ test standard represent with io object function done")
@@ -55,6 +62,27 @@ def test_represent_for_skipped_detector_backend_with_image_path():
     logger.info("✅ test represent function for skipped detector and image path input backend done")
+def test_represent_for_preloaded_image():
+    face_img = "dataset/img5.jpg"
+    img = cv2.imread(face_img)
+    img_objs = DeepFace.represent(img_path=img)
+    # type should be list of dict
+    assert isinstance(img_objs, list)
+    assert len(img_objs) >= 1
+    for img_obj in img_objs:
+        assert isinstance(img_obj, dict)
+        assert "embedding" in img_obj.keys()
+        assert "facial_area" in img_obj.keys()
+        assert isinstance(img_obj["facial_area"], dict)
+        assert "x" in img_obj["facial_area"].keys()
+        assert "y" in img_obj["facial_area"].keys()
+        assert "w" in img_obj["facial_area"].keys()
+        assert "h" in img_obj["facial_area"].keys()
+        assert "face_confidence" in img_obj.keys()
+    logger.info("✅ test represent function for preloaded image done")
 def test_represent_for_skipped_detector_backend_with_preloaded_image():
     face_img = "dataset/img5.jpg"
     img = cv2.imread(face_img)
@@ -84,12 +112,6 @@ def test_max_faces():
 def test_represent_detector_backend():
-    """
-    There shouldn't be a difference between:
-    - Using a detector backend provided by `represent`
-    - Manually calling a detector backend, then calling `represent`.
-    """
     # Results using a detection backend.
     results_1 = DeepFace.represent(img_path="dataset/img1.jpg")
     assert len(results_1) == 1
@@ -108,3 +130,108 @@ def test_represent_detector_backend():
     embedding_2 = results_2[0]['embedding']
     assert embedding_1 == embedding_2
     logger.info("✅ test represent function for consistent output.")
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_list_input(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+        "dataset/couple.jpg",
+    ]
+    expected_faces = [1, 1, 1, 1, 1, 2]
+    batched_embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+    assert len(batched_embedding_objs) == len(
+        img_paths
+    ), f"Expected {len(img_paths)} embeddings, got {len(batched_embedding_objs)}"
+    # the last image has two faces
+    for idx, embedding_objs in enumerate(batched_embedding_objs):
+        # batched_embedding_objs is already a list; each embedding_objs should be a list of dict
+        assert isinstance(embedding_objs, list)
+        for embedding_obj in embedding_objs:
+            assert isinstance(embedding_obj, dict)
+        assert expected_faces[idx] == len(
+            embedding_objs
+        ), f"{img_paths[idx]} has {expected_faces[idx]} faces, but got {len(embedding_objs)} embeddings!"
+    for idx, img_path in enumerate(img_paths):
+        single_embedding_objs = DeepFace.represent(img_path=img_path, model_name=model_name)
+        # type should be list of dict for single input
+        assert isinstance(single_embedding_objs, list)
+        for embedding_obj in single_embedding_objs:
+            assert isinstance(embedding_obj, dict)
+        assert len(single_embedding_objs) == len(batched_embedding_objs[idx])
+        for alpha, beta in zip(single_embedding_objs, batched_embedding_objs[idx]):
+            assert np.allclose(
+                alpha["embedding"], beta["embedding"], rtol=1e-2, atol=1e-2
+            ), "Embeddings do not match within tolerance"
+    logger.info(f"✅ test batch represent function with string input for model {model_name} done")
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_numpy_input(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+        "dataset/couple.jpg",
+    ]
+    expected_faces = [1, 1, 1, 1, 1, 2]
+    imgs = []
+    for img_path in img_paths:
+        img = cv2.imread(img_path)
+        img = cv2.resize(img, (1000, 1000))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        imgs.append(img)
+    imgs = np.array(imgs)
+    assert imgs.ndim == 4 and imgs.shape[0] == len(img_paths)
+    batched_embedding_objs = DeepFace.represent(img_path=imgs, model_name=model_name)
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+    for idx, batched_embedding_obj in enumerate(batched_embedding_objs):
+        assert isinstance(batched_embedding_obj, list)
+        # it also has to have the expected number of faces
+        assert len(batched_embedding_obj) == expected_faces[idx]
+        for embedding_obj in batched_embedding_obj:
+            assert isinstance(embedding_obj, dict)
+    # we should have the same number of embeddings as the number of images
+    assert len(batched_embedding_objs) == len(img_paths)
+    logger.info(f"✅ test batch represent function with numpy input for model {model_name} done")