diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py
index 2a6c94c..ae0fb0b 100644
--- a/deepface/DeepFace.py
+++ b/deepface/DeepFace.py
@@ -2,7 +2,7 @@
 import os
 import warnings
 import logging
-from typing import Any, Dict, IO, List, Union, Optional
+from typing import Any, Dict, IO, List, Union, Optional, Sequence
 
 # this has to be set before importing tensorflow
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
@@ -376,7 +376,7 @@ def find(
 
 
 def represent(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -390,10 +390,13 @@
     Represent facial images as multi-dimensional vector embeddings.
 
     Args:
-        img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+            The exact path to the image, a numpy array
             in BGR format, a file object that supports at least `.read` and is opened in binary
             mode, or a base64 encoded image. If the source image contains multiple faces,
-            the result will include information for each detected face.
+            the result will include information for each detected face. If a sequence is provided,
+            each element should be a string, numpy array, or file object representing an image,
+            and the images will be processed as a single batch.
 
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
diff --git a/deepface/models/FacialRecognition.py b/deepface/models/FacialRecognition.py
index a6ee7b5..410a033 100644
--- a/deepface/models/FacialRecognition.py
+++ b/deepface/models/FacialRecognition.py
@@ -18,7 +18,7 @@ class FacialRecognition(ABC):
     input_shape: Tuple[int, int]
     output_shape: int
 
-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         if not isinstance(self.model, Model):
             raise ValueError(
                 "You must overwrite forward method if it is not a keras model,"
@@ -26,4 +26,7 @@
             )
         # model.predict causes memory issue when it is called in a for loop
         # embedding = model.predict(img, verbose=0)[0].tolist()
-        return self.model(img, training=False).numpy()[0].tolist()
+        embeddings = self.model(img, training=False).numpy()
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()
diff --git a/deepface/models/facial_recognition/Dlib.py b/deepface/models/facial_recognition/Dlib.py
index 7b29dec..a2e5ca6 100644
--- a/deepface/models/facial_recognition/Dlib.py
+++ b/deepface/models/facial_recognition/Dlib.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union
 
 # 3rd party dependencies
 import numpy as np
@@ -26,24 +26,22 @@ class DlibClient(FacialRecognition):
         self.input_shape = (150, 150)
         self.output_shape = 128
 
-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with Dlib model.
             This model necessitates the override of the forward method
            because it is not a keras model.
         Args:
-            img (np.ndarray): pre-loaded image in BGR
+            img (np.ndarray): pre-loaded image(s) in BGR
         Returns
-            embeddings (list): multi-dimensional vector
+            embeddings (list of lists or list of floats): multi-dimensional vector(s)
         """
-        # return self.model.predict(img)[0].tolist()
-
-        # extract_faces returns 4 dimensional images
-        if len(img.shape) == 4:
-            img = img[0]
+        # handle the single-image case by adding a batch dimension
+        if len(img.shape) == 3:
+            img = np.expand_dims(img, axis=0)
 
         # bgr to rgb
-        img = img[:, :, ::-1]
+        img = img[:, :, :, ::-1]
 
         # img is in scale of [0, 1] but expected [0, 255]
         if img.max() <= 1:
@@ -51,10 +49,11 @@
 
         img = img.astype(np.uint8)
 
-        img_representation = self.model.model.compute_face_descriptor(img)
-        img_representation = np.array(img_representation)
-        img_representation = np.expand_dims(img_representation, axis=0)
-        return img_representation[0].tolist()
+        embeddings = self.model.model.compute_face_descriptor(img)
+        embeddings = [np.array(embedding).tolist() for embedding in embeddings]
+        if len(embeddings) == 1:
+            return embeddings[0]
+        return embeddings
 
 
 class DlibResNet:
diff --git a/deepface/models/facial_recognition/SFace.py b/deepface/models/facial_recognition/SFace.py
index f6a01ca..eeebbe3 100644
--- a/deepface/models/facial_recognition/SFace.py
+++ b/deepface/models/facial_recognition/SFace.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, List
+from typing import Any, List, Union
 
 # 3rd party dependencies
 import numpy as np
@@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
         self.input_shape = (112, 112)
         self.output_shape = 128
 
-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with SFace model
             This model necessitates the override of the forward method
@@ -37,14 +37,17 @@
         Returns
             embeddings (list): multi-dimensional vector
         """
-        # return self.model.predict(img)[0].tolist()
+        input_blob = (img * 255).astype(np.uint8)
 
-        # revert the image to original format and preprocess using the model
-        input_blob = (img[0] * 255).astype(np.uint8)
+        embeddings = []
+        for i in range(input_blob.shape[0]):
+            embedding = self.model.model.feature(input_blob[i])
+            embeddings.append(embedding)
+        embeddings = np.concatenate(embeddings, axis=0)
 
-        embeddings = self.model.model.feature(input_blob)
-
-        return embeddings[0].tolist()
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()
 
 
 def load_model(
diff --git a/deepface/models/facial_recognition/VGGFace.py b/deepface/models/facial_recognition/VGGFace.py
index bfcbcad..bffd0d6 100644
--- a/deepface/models/facial_recognition/VGGFace.py
+++ b/deepface/models/facial_recognition/VGGFace.py
@@ -57,8 +57,7 @@ class VggFaceClient(FacialRecognition):
     def forward(self, img: np.ndarray) -> List[float]:
         """
         Generates embeddings using the VGG-Face model.
-        This method incorporates an additional normalization layer,
-        necessitating the override of the forward method.
+        This method incorporates an additional normalization layer.
 
         Args:
             img (np.ndarray): pre-loaded image in BGR
@@ -70,8 +69,14 @@
         # having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966)
         # instead we are now calculating it with traditional way not with keras backend
-        embedding = self.model(img, training=False).numpy()[0].tolist()
-        embedding = verification.l2_normalize(embedding)
+        embedding = super().forward(img)
+        if (
+            isinstance(embedding, list) and
+            isinstance(embedding[0], list)
+        ):
+            embedding = verification.l2_normalize(embedding, axis=1)
+        else:
+            embedding = verification.l2_normalize(embedding)
         return embedding.tolist()
 
 
diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py
index d880645..56eaef2 100644
--- a/deepface/modules/representation.py
+++ b/deepface/modules/representation.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union, Optional, Sequence, IO
 
 # 3rd party dependencies
 import numpy as np
@@ -11,7 +11,7 @@ from deepface.models.FacialRecognition import FacialRecognition
 
 
 def represent(
-    img_path: Union[str, np.ndarray],
+    img_path: Union[str, IO[bytes], np.ndarray, Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
    enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -25,9 +25,11 @@ def represent(
     Represent facial images as multi-dimensional vector embeddings.
 
     Args:
-        img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
-            or a base64 encoded image. If the source image contains multiple faces, the result will
-            include information for each detected face.
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+            The exact path to the image, a numpy array in BGR format, a file object opened
+            in binary mode, a base64 encoded image, or a sequence of these.
+            If the source image contains multiple faces,
+            the result will include information for each detected face.
 
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@@ -70,70 +72,95 @@
         task="facial_recognition", model_name=model_name
     )
 
-    # ---------------------------------
-    # we have run pre-process in verification. so, this can be skipped if it is coming from verify.
-    target_size = model.input_shape
-    if detector_backend != "skip":
-        img_objs = detection.extract_faces(
-            img_path=img_path,
-            detector_backend=detector_backend,
-            grayscale=False,
-            enforce_detection=enforce_detection,
-            align=align,
-            expand_percentage=expand_percentage,
-            anti_spoofing=anti_spoofing,
-            max_faces=max_faces,
-        )
-    else:  # skip
-        # Try load. If load error, will raise exception internal
-        img, _ = image_utils.load_image(img_path)
+    # Handle list of image paths or 4D numpy array
+    if isinstance(img_path, list):
+        images = img_path
+    elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
+        images = [img_path[i] for i in range(img_path.shape[0])]
+    else:
+        images = [img_path]
 
-        if len(img.shape) != 3:
-            raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
+    batch_images = []
+    batch_regions = []
+    batch_confidences = []
 
-        # make dummy region and confidence to keep compatibility with `extract_faces`
-        img_objs = [
-            {
-                "face": img,
-                "facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
-                "confidence": 0,
-            }
-        ]
-    # ---------------------------------
+    for single_img_path in images:
+        # ---------------------------------
+        # we have run pre-process in verification.
+        # so, this can be skipped if it is coming from verify.
+        target_size = model.input_shape
+        if detector_backend != "skip":
+            img_objs = detection.extract_faces(
+                img_path=single_img_path,
+                detector_backend=detector_backend,
+                grayscale=False,
+                enforce_detection=enforce_detection,
+                align=align,
+                expand_percentage=expand_percentage,
+                anti_spoofing=anti_spoofing,
+                max_faces=max_faces,
+            )
+        else:  # skip
+            # Try load. If load error, will raise exception internal
+            img, _ = image_utils.load_image(single_img_path)
 
-    if max_faces is not None and max_faces < len(img_objs):
-        # sort as largest facial areas come first
-        img_objs = sorted(
-            img_objs,
-            key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
-            reverse=True,
-        )
-        # discard rest of the items
-        img_objs = img_objs[0:max_faces]
+            if len(img.shape) != 3:
+                raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
 
-    for img_obj in img_objs:
-        if anti_spoofing is True and img_obj.get("is_real", True) is False:
-            raise ValueError("Spoof detected in the given image.")
-        img = img_obj["face"]
+            # make dummy region and confidence to keep compatibility with `extract_faces`
+            img_objs = [
+                {
+                    "face": img,
+                    "facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
+                    "confidence": 0,
+                }
+            ]
+        # ---------------------------------
 
-        # bgr to rgb
-        img = img[:, :, ::-1]
+        if max_faces is not None and max_faces < len(img_objs):
+            # sort as largest facial areas come first
+            img_objs = sorted(
+                img_objs,
+                key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
+                reverse=True,
+            )
+            # discard rest of the items
+            img_objs = img_objs[0:max_faces]
 
-        region = img_obj["facial_area"]
-        confidence = img_obj["confidence"]
+        for img_obj in img_objs:
+            if anti_spoofing is True and img_obj.get("is_real", True) is False:
+                raise ValueError("Spoof detected in the given image.")
+            img = img_obj["face"]
 
-        # resize to expected shape of ml model
-        img = preprocessing.resize_image(
-            img=img,
-            # thanks to DeepId (!)
-            target_size=(target_size[1], target_size[0]),
-        )
+            # bgr to rgb
+            img = img[:, :, ::-1]
 
-        # custom normalization
-        img = preprocessing.normalize_input(img=img, normalization=normalization)
+            region = img_obj["facial_area"]
+            confidence = img_obj["confidence"]
 
-        embedding = model.forward(img)
+            # resize to expected shape of ml model
+            img = preprocessing.resize_image(
+                img=img,
+                # thanks to DeepId (!)
+                target_size=(target_size[1], target_size[0]),
+            )
 
+            # custom normalization
+            img = preprocessing.normalize_input(img=img, normalization=normalization)
+
+            batch_images.append(img)
+            batch_regions.append(region)
+            batch_confidences.append(confidence)
+
+    # Convert list of images to a numpy array for batch processing
+    batch_images = np.concatenate(batch_images, axis=0)
+
+    # Forward pass through the model for the entire batch
+    embeddings = model.forward(batch_images)
+    if len(batch_images) == 1:
+        embeddings = [embeddings]
+
+    for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
         resp_objs.append(
             {
                 "embedding": embedding,
diff --git a/tests/test_represent.py b/tests/test_represent.py
index b33def7..bc83a4e 100644
--- a/tests/test_represent.py
+++ b/tests/test_represent.py
@@ -2,7 +2,8 @@
 import io
 
 import cv2
 import pytest
+import numpy as np
 
 # project dependencies
 from deepface import DeepFace
@@ -81,3 +82,42 @@ def test_max_faces():
     max_faces = 1
     results = DeepFace.represent(img_path="dataset/couple.jpg", max_faces=max_faces)
     assert len(results) == max_faces
+
+
+@pytest.mark.parametrize("model_name", [
+    "VGG-Face",
+    "Facenet",
+    "SFace",
+])
+def test_batched_represent(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+    ]
+
+    embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
+    assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
+
+    if model_name == "VGG-Face":
+        for embedding_obj in embedding_objs:
+            embedding = embedding_obj["embedding"]
+            logger.debug(f"Function returned {len(embedding)} dimensional vector")
+            assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
+
+    embedding_objs_one_by_one = [
+        embedding_obj
+        for img_path in img_paths
+        for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
+    ]
+    for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
+        assert np.allclose(
+            embedding_obj_one_by_one["embedding"],
+            embedding_obj["embedding"],
+            rtol=1e-2,
+            atol=1e-2,
+        ), "Embeddings do not match within tolerance"
+
+    logger.info(f"✅ test batch represent function for model {model_name} done")
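For reference, a minimal usage sketch (not part of the patch) of the batched `represent` call this diff introduces; the image paths are illustrative placeholders, and each image is assumed to contain at least one detectable face:

```python
from deepface import DeepFace

# A list input triggers the new batch path: faces from all images are
# embedded in a single forward pass through the recognition model.
img_paths = ["dataset/img1.jpg", "dataset/img2.jpg"]
embedding_objs = DeepFace.represent(img_path=img_paths, model_name="VGG-Face")

# The result remains a flat list with one entry per detected face across
# all inputs, each carrying the embedding, face region, and confidence.
for obj in embedding_objs:
    print(len(obj["embedding"]), obj["facial_area"], obj["confidence"])
```

Keeping the output a flat list (rather than one list per input image) preserves the existing single-image contract, which is why the one-by-one and batched results in the test above can be compared entry for entry.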