Raghucharan16 2025-02-20 22:14:31 +05:30
commit 41ae9bbcf3
8 changed files with 218 additions and 85 deletions

File: deepface/DeepFace.py

@@ -206,9 +206,9 @@ def analyze(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
(List[List[Dict[str, Any]]]): A list of analysis results if a batched image is received,
explained below.
(List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
the analysis results for a detected face. Each dictionary in the list contains the
following keys:
@@ -385,12 +385,12 @@ def represent(
normalization: str = "base",
anti_spoofing: bool = False,
max_faces: Optional[int] = None,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Represent facial images as multi-dimensional vector embeddings.
Args:
img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
The exact path to the image, a numpy array
in BGR format, a file object that supports at least `.read` and is opened in binary
mode, or a base64 encoded image. If the source image contains multiple faces,
@@ -423,8 +423,9 @@ def represent(
max_faces (int): Set a limit on the number of faces to be processed (default is None).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, each containing the
following fields:
results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
The result type becomes List[List[Dict[str, Any]]] if a batched input is passed.
Each dictionary contains the following fields:
- embedding (List[float]): Multidimensional vector representing facial features.
The number of dimensions varies based on the reference model
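To make the new contract concrete, here is a minimal usage sketch of the batched represent API described above (image paths are placeholders):

from deepface import DeepFace

# Single input: a flat list of dicts, one per detected face.
single = DeepFace.represent(img_path="img1.jpg")
print(len(single[0]["embedding"]))  # dimension depends on the model, e.g. 4096 for VGG-Face

# Batched input: a list of lists, one inner list per input image.
batched = DeepFace.represent(img_path=["img1.jpg", "img2.jpg"])
for per_image in batched:
    for face in per_image:
        print(face["facial_area"], len(face["embedding"]))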

File: deepface/models/Demography.py

@@ -24,7 +24,7 @@ class Demography(ABC):
def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
"""
Predict for single image or batched images.
This method uses the legacy prediction path when it receives a single image as input,
and switches to batch prediction when it receives batched images.
Args:
@@ -35,11 +35,11 @@ class Demography(ABC):
with x = image width, y = image height and c = channel
The channel dimension will be 1 if input is grayscale. (For emotion model)
"""
if not self.model_name:  # Check if called from derived class
raise NotImplementedError("no model selected")
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
if img_batch.shape[0] == 1:  # Single image
# Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :]
@@ -48,10 +48,8 @@ class Demography(ABC):
return self.model.predict_on_batch(img_batch)
def _preprocess_batch_or_single_input(
self,
img: Union[np.ndarray, List[np.ndarray]]
self, img: Union[np.ndarray, List[np.ndarray]]
) -> np.ndarray:
"""
Preprocess single or batch of images, return as 4-D numpy array.
Args:
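The single-versus-batch dispatch above can be read in isolation. Below is a standalone sketch of the same shape contract, using a toy model function rather than the library's Keras model (all names here are illustrative, not deepface API):

import numpy as np

def predict_single_or_batch(model_fn, img_batch: np.ndarray) -> np.ndarray:
    # Mirrors _predict_internal's contract: (1, x, y, c) -> 1-D vector,
    # (n, x, y, c) with n > 1 -> 2-D (n, classes) array.
    assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
    if img_batch.shape[0] == 1:
        # single image: legacy path, squeeze the batch axis from the result
        return model_fn(img_batch)[0, :]
    # batched images: keep the batch axis
    return model_fn(img_batch)

# toy stand-in model: mean pixel intensity repeated over 3 "classes"
toy = lambda b: np.tile(b.mean(axis=(1, 2, 3))[:, None], (1, 3))
print(predict_single_or_batch(toy, np.ones((1, 224, 224, 3))).shape)  # (3,)
print(predict_single_or_batch(toy, np.ones((5, 224, 224, 3))).shape)  # (5, 3)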

File: deepface/models/demography/Age.py

@@ -13,7 +13,6 @@ from deepface.commons.logger import Logger
logger = Logger()
# ----------------------------------------
# dependency configurations
tf_version = package_utils.get_tf_major_version()
@@ -25,12 +24,11 @@ else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# ----------------------------------------
WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
)
# pylint: disable=too-few-public-methods
class ApparentAgeClient(Demography):
"""
@@ -49,7 +47,7 @@ class ApparentAgeClient(Demography):
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
np.ndarray (age_classes,) if single image,
np.ndarray (n, age_classes) if batched images.
"""
# Preprocessing input image or image list.
@@ -59,11 +57,10 @@ class ApparentAgeClient(Demography):
age_predictions = self._predict_internal(imgs)
# Calculate apparent ages
if len(age_predictions.shape) == 1:  # Single prediction list
return find_apparent_age(age_predictions)
return np.array([
find_apparent_age(age_prediction) for age_prediction in age_predictions])
return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])
def load_model(
@@ -100,6 +97,7 @@ def load_model(
return age_model
def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
"""
Find the apparent age prediction from a given probability distribution over ages
@@ -108,7 +106,9 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
Returns:
apparent_age (float)
"""
assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
assert (
len(age_predictions.shape) == 1
), f"Input should be a list of predictions, \
not batched. Got shape: {age_predictions.shape}"
output_indexes = np.arange(0, 101)
apparent_age = np.sum(age_predictions * output_indexes)
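find_apparent_age is an expected value over the 101 age classes: apparent_age = sum(p_i * i) for i in 0..100. A small worked example with a made-up probability vector:

import numpy as np

# Made-up softmax-like output concentrated around age 30 (illustrative only).
ages = np.arange(0, 101)
probas = np.exp(-0.5 * ((ages - 30) / 2.0) ** 2)
probas /= probas.sum()  # normalize so it behaves like model output

apparent_age = np.sum(probas * ages)  # the same weighted sum as above
print(round(float(apparent_age), 2))  # ~30.0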

File: deepface/modules/demography.py

@@ -123,7 +123,6 @@ def analyze(
batch_resp_obj.append(resp_obj)
return batch_resp_obj
# if actions is passed as a tuple with a single item, interestingly it becomes str here
if isinstance(actions, str):
actions = (actions,)
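The batched branch above (batch_resp_obj) is triggered by stacking same-sized BGR images into a 4-D numpy array. A short sketch of that entry point (paths are placeholders):

import cv2
import numpy as np
from deepface import DeepFace

img1 = cv2.resize(cv2.imread("img1.jpg"), (500, 500))
img2 = cv2.resize(cv2.imread("img2.jpg"), (500, 500))
batch = np.stack([img1, img2])  # shape (2, 500, 500, 3), BGR

# One list of per-face result dicts per image in the batch.
results = DeepFace.analyze(batch, actions=("age", "gender"), silent=True)
assert len(results) == 2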

File: deepface/modules/recognition.py

@@ -398,6 +398,7 @@ def __find_bulk_embeddings(
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
color_face='bgr' # `represent` expects images in bgr format.
)
except ValueError as err:
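The color_face='bgr' argument matters because extract_faces returns faces in RGB by default, while represent expects numpy input in BGR. A minimal sketch of the same handshake, mirroring the test added later in this commit:

from deepface import DeepFace

faces = DeepFace.extract_faces(img_path="dataset/img1.jpg", color_face="bgr")
embedding = DeepFace.represent(
    img_path=faces[0]["face"], detector_backend="skip"
)[0]["embedding"]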

File: deepface/modules/representation.py

@@ -20,12 +20,12 @@ def represent(
normalization: str = "base",
anti_spoofing: bool = False,
max_faces: Optional[int] = None,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Represent facial images as multi-dimensional vector embeddings.
Args:
img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
The exact path to the image, a numpy array in BGR format,
a base64 encoded image, or a sequence of these.
If the source image contains multiple faces,
@@ -53,8 +53,9 @@ def represent(
max_faces (int): Set a limit on the number of faces to be processed (default is None).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, each containing the
following fields:
results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
The result type becomes List[List[Dict[str, Any]]] if a batched input is passed.
Each dictionary contains the following fields:
- embedding (List[float]): Multidimensional vector representing facial features.
The number of dimensions varies based on the reference model
@@ -80,16 +81,13 @@ def represent(
else:
images = [img_path]
batch_images = []
batch_regions = []
batch_confidences = []
batch_images, batch_regions, batch_confidences, batch_indexes = [], [], [], []
for single_img_path in images:
# ---------------------------------
# we have run pre-process in verification.
# so, this can be skipped if it is coming from verify.
for idx, single_img_path in enumerate(images):
# we have already run pre-processing in verification, so skip it if the input comes from verify
target_size = model.input_shape
if detector_backend != "skip":
# Images are returned in RGB format.
img_objs = detection.extract_faces(
img_path=single_img_path,
detector_backend=detector_backend,
@@ -107,6 +105,9 @@ def represent(
if len(img.shape) != 3:
raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
# Convert to RGB format to keep compatibility with `extract_faces`.
img = img[:, :, ::-1]
# make dummy region and confidence to keep compatibility with `extract_faces`
img_objs = [
{
@@ -130,9 +131,10 @@ def represent(
for img_obj in img_objs:
if anti_spoofing is True and img_obj.get("is_real", True) is False:
raise ValueError("Spoof detected in the given image.")
img = img_obj["face"]
# bgr to rgb
# rgb to bgr
img = img[:, :, ::-1]
region = img_obj["facial_area"]
@@ -151,22 +153,25 @@ def represent(
batch_images.append(img)
batch_regions.append(region)
batch_confidences.append(confidence)
batch_indexes.append(idx)
# Convert list of images to a numpy array for batch processing
batch_images = np.concatenate(batch_images, axis=0)
# Forward pass through the model for the entire batch
embeddings = model.forward(batch_images)
if len(batch_images) == 1:
embeddings = [embeddings]
for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
resp_objs.append(
{
"embedding": embedding,
"facial_area": region,
"face_confidence": confidence,
}
)
for idx in range(0, len(images)):
resp_obj = []
for idy, batch_index in enumerate(batch_indexes):
if idx == batch_index:
resp_obj.append(
{
"embedding": embeddings if len(batch_images) == 1 else embeddings[idy],
"facial_area": batch_regions[idy],
"face_confidence": batch_confidences[idy],
}
)
resp_objs.append(resp_obj)
return resp_objs
return resp_objs[0] if len(images) == 1 else resp_objs
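The loop above regroups the flat list of embeddings back into one list per input image using batch_indexes. A minimal standalone sketch of that grouping step (values are illustrative):

embeddings = ["emb_a", "emb_b", "emb_c"]  # one entry per detected face
batch_indexes = [0, 1, 1]                 # image index each face came from
num_images = 2

grouped = [
    [emb for emb, owner in zip(embeddings, batch_indexes) if owner == img_idx]
    for img_idx in range(num_images)
]
print(grouped)  # [['emb_a'], ['emb_b', 'emb_c']]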

File: tests/test_analyze.py

@@ -144,26 +144,39 @@ def test_analyze_for_different_detectors():
else:
assert result["gender"]["Man"] < result["gender"]["Woman"]
def test_analyze_for_batched_image():
img = "dataset/img4.jpg"
def test_analyze_for_numpy_batched_image():
img1_path = "dataset/img4.jpg"
img2_path = "dataset/couple.jpg"
# Copy and combine the same image to create multiple faces
img = cv2.imread(img)
img = np.stack([img, img])
assert len(img.shape) == 4 # Check dimension.
assert img.shape[0] == 2 # Check batch size.
img1 = cv2.imread(img1_path)
img2 = cv2.imread(img2_path)
expected_num_faces = [1, 2]
img1 = cv2.resize(img1, (500, 500))
img2 = cv2.resize(img2, (500, 500))
img = np.stack([img1, img2])
assert len(img.shape) == 4 # Check dimension.
assert img.shape[0] == 2 # Check batch size.
demography_batch = DeepFace.analyze(img, silent=True)
# 2 images in batch, so 2 demography objects.
assert len(demography_batch) == 2
for demography_objs in demography_batch:
assert len(demography_objs) == 1 # 1 face in each image
for demography in demography_objs: # Iterate over faces
assert type(demography) == dict # Check type
for i, demography_objs in enumerate(demography_batch):
assert len(demography_objs) == expected_num_faces[i]
for demography in demography_objs: # Iterate over faces
assert isinstance(demography, dict) # Check type
assert demography["age"] > 20 and demography["age"] < 40
assert demography["dominant_gender"] == "Woman"
assert demography["dominant_gender"] in ["Woman", "Man"]
logger.info("✅ test analyze for multiple faces done")
def test_batch_detect_age_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -176,6 +189,7 @@ def test_batch_detect_age_for_multiple_faces():
assert np.array_equal(int(results[0]), int(results[1]))
logger.info("✅ test batch detect age for multiple faces done")
def test_batch_detect_emotion_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -187,6 +201,7 @@ def test_batch_detect_emotion_for_multiple_faces():
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect emotion for multiple faces done")
def test_batch_detect_gender_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -198,6 +213,7 @@ def test_batch_detect_gender_for_multiple_faces():
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect gender for multiple faces done")
def test_batch_detect_race_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -207,4 +223,4 @@ def test_batch_detect_race_for_multiple_faces():
assert len(results) == 2
# Check two races are the same
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect race for multiple faces done")
logger.info("✅ test batch detect race for multiple faces done")

File: tests/test_represent.py

@@ -15,7 +15,12 @@ logger = Logger()
def test_standard_represent():
img_path = "dataset/img1.jpg"
embedding_objs = DeepFace.represent(img_path)
# type should be list of dict
assert isinstance(embedding_objs, list)
for embedding_obj in embedding_objs:
assert isinstance(embedding_obj, dict)
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
assert len(embedding) == 4096
@@ -25,18 +30,18 @@ def test_standard_represent():
def test_standard_represent_with_io_object():
img_path = "dataset/img1.jpg"
default_embedding_objs = DeepFace.represent(img_path)
io_embedding_objs = DeepFace.represent(open(img_path, 'rb'))
io_embedding_objs = DeepFace.represent(open(img_path, "rb"))
assert default_embedding_objs == io_embedding_objs
# Confirm non-seekable io objects are handled properly
io_obj = io.BytesIO(open(img_path, 'rb').read())
io_obj = io.BytesIO(open(img_path, "rb").read())
io_obj.seek = None
no_seek_io_embedding_objs = DeepFace.represent(io_obj)
assert default_embedding_objs == no_seek_io_embedding_objs
# Confirm non-image io objects raise exceptions
with pytest.raises(ValueError, match='Failed to decode image'):
DeepFace.represent(io.BytesIO(open(r'../requirements.txt', 'rb').read()))
with pytest.raises(ValueError, match="Failed to decode image"):
DeepFace.represent(io.BytesIO(open(r"../requirements.txt", "rb").read()))
logger.info("✅ test standard represent with io object function done")
@@ -57,6 +62,27 @@ def test_represent_for_skipped_detector_backend_with_image_path():
logger.info("✅ test represent function for skipped detector and image path input backend done")
def test_represent_for_preloaded_image():
face_img = "dataset/img5.jpg"
img = cv2.imread(face_img)
img_objs = DeepFace.represent(img_path=img)
# type should be list of dict
assert isinstance(img_objs, list)
assert len(img_objs) >= 1
for img_obj in img_objs:
assert isinstance(img_obj, dict)
assert "embedding" in img_obj.keys()
assert "facial_area" in img_obj.keys()
assert isinstance(img_obj["facial_area"], dict)
assert "x" in img_obj["facial_area"].keys()
assert "y" in img_obj["facial_area"].keys()
assert "w" in img_obj["facial_area"].keys()
assert "h" in img_obj["facial_area"].keys()
assert "face_confidence" in img_obj.keys()
logger.info("✅ test represent function for skipped detector and preloaded image done")
def test_represent_for_skipped_detector_backend_with_preloaded_image():
face_img = "dataset/img5.jpg"
img = cv2.imread(face_img)
@@ -85,40 +111,127 @@ def test_max_faces():
assert len(results) == max_faces
@pytest.mark.parametrize("model_name", [
"VGG-Face",
"Facenet",
"SFace",
])
def test_batched_represent(model_name):
def test_represent_detector_backend():
# Results using a detection backend.
results_1 = DeepFace.represent(img_path="dataset/img1.jpg")
assert len(results_1) == 1
# Results performing face extraction first.
faces = DeepFace.extract_faces(img_path="dataset/img1.jpg", color_face='bgr')
assert len(faces) == 1
# Images sent into represent need to be in BGR format.
img = faces[0]['face']
results_2 = DeepFace.represent(img_path=img, detector_backend="skip")
assert len(results_2) == 1
# The embeddings should be the exact same for both cases.
embedding_1 = results_1[0]['embedding']
embedding_2 = results_2[0]['embedding']
assert embedding_1 == embedding_2
logger.info("✅ test represent function for consistent output.")
@pytest.mark.parametrize(
"model_name",
[
"VGG-Face",
"Facenet",
"SFace",
],
)
def test_batched_represent_for_list_input(model_name):
img_paths = [
"dataset/img1.jpg",
"dataset/img2.jpg",
"dataset/img3.jpg",
"dataset/img4.jpg",
"dataset/img5.jpg",
"dataset/couple.jpg",
]
embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
expected_faces = [1, 1, 1, 1, 1, 2]
if model_name == "VGG-Face":
batched_embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
# type should be list of list of dict for batch input
assert isinstance(batched_embedding_objs, list)
assert len(batched_embedding_objs) == len(
img_paths
), f"Expected {len(img_paths)} embeddings, got {len(batched_embedding_objs)}"
# the last one has two faces
for idx, embedding_objs in enumerate(batched_embedding_objs):
# batched_embedding_objs is already a list; each embedding_objs entry should be a list of dicts
assert isinstance(embedding_objs, list)
for embedding_obj in embedding_objs:
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
assert isinstance(embedding_obj, dict)
embedding_objs_one_by_one = [
embedding_obj
for img_path in img_paths
for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
assert expected_faces[idx] == len(
embedding_objs
), f"{img_paths[idx]} has {expected_faces[idx]} faces, but got {len(embedding_objs)} embeddings!"
for idx, img_path in enumerate(img_paths):
single_embedding_objs = DeepFace.represent(img_path=img_path, model_name=model_name)
# type should be list of dict for single input
assert isinstance(single_embedding_objs, list)
for embedding_obj in single_embedding_objs:
assert isinstance(embedding_obj, dict)
assert len(single_embedding_objs) == len(batched_embedding_objs[idx])
for alpha, beta in zip(single_embedding_objs, batched_embedding_objs[idx]):
assert np.allclose(
alpha["embedding"], beta["embedding"], rtol=1e-2, atol=1e-2
), "Embeddings do not match within tolerance"
logger.info(f"✅ test batch represent function with string input for model {model_name} done")
@pytest.mark.parametrize(
"model_name",
[
"VGG-Face",
"Facenet",
"SFace",
],
)
def test_batched_represent_for_numpy_input(model_name):
img_paths = [
"dataset/img1.jpg",
"dataset/img2.jpg",
"dataset/img3.jpg",
"dataset/img4.jpg",
"dataset/img5.jpg",
"dataset/couple.jpg",
]
for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
assert np.allclose(
embedding_obj_one_by_one["embedding"],
embedding_obj["embedding"],
rtol=1e-2,
atol=1e-2
), "Embeddings do not match within tolerance"
expected_faces = [1, 1, 1, 1, 1, 2]
logger.info(f"✅ test batch represent function for model {model_name} done")
imgs = []
for img_path in img_paths:
img = cv2.imread(img_path)
img = cv2.resize(img, (1000, 1000))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# print(img.shape)
imgs.append(img)
imgs = np.array(imgs)
assert imgs.ndim == 4 and imgs.shape[0] == len(img_paths)
batched_embedding_objs = DeepFace.represent(img_path=imgs, model_name=model_name)
# type should be list of list of dict for batch input
assert isinstance(batched_embedding_objs, list)
for idx, batched_embedding_obj in enumerate(batched_embedding_objs):
assert isinstance(batched_embedding_obj, list)
# it also has to have the expected number of faces
assert len(batched_embedding_obj) == expected_faces[idx]
for embedding_obj in batched_embedding_obj:
assert isinstance(embedding_obj, dict)
# we should have the same number of embeddings as the number of images
assert len(batched_embedding_objs) == len(img_paths)
logger.info(f"✅ test batch represent function with numpy input for model {model_name} done")