Merge pull request #1433 from galthran-wq/batching

Batching on `.represent` to improve performance and utilize GPU in full
This commit is contained in:
Sefik Ilkin Serengil 2025-02-16 19:58:10 +00:00 committed by GitHub
commit ca73032969
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 174 additions and 91 deletions

View File

@ -2,7 +2,7 @@
import os import os
import warnings import warnings
import logging import logging
from typing import Any, Dict, IO, List, Union, Optional from typing import Any, Dict, IO, List, Union, Optional, Sequence
# this has to be set before importing tensorflow # this has to be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1" os.environ["TF_USE_LEGACY_KERAS"] = "1"
@ -376,7 +376,7 @@ def find(
def represent( def represent(
img_path: Union[str, np.ndarray, IO[bytes]], img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -390,10 +390,13 @@ def represent(
Represent facial images as multi-dimensional vector embeddings. Represent facial images as multi-dimensional vector embeddings.
Args: Args:
img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
The exact path to the image, a numpy array
in BGR format, a file object that supports at least `.read` and is opened in binary in BGR format, a file object that supports at least `.read` and is opened in binary
mode, or a base64 encoded image. If the source image contains multiple faces, mode, or a base64 encoded image. If the source image contains multiple faces,
the result will include information for each detected face. the result will include information for each detected face. If a sequence is provided,
each element should be a string or numpy array representing an image, and the function
will process images in batch.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet

View File

@ -18,7 +18,7 @@ class FacialRecognition(ABC):
input_shape: Tuple[int, int] input_shape: Tuple[int, int]
output_shape: int output_shape: int
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
if not isinstance(self.model, Model): if not isinstance(self.model, Model):
raise ValueError( raise ValueError(
"You must overwrite forward method if it is not a keras model," "You must overwrite forward method if it is not a keras model,"
@ -26,4 +26,9 @@ class FacialRecognition(ABC):
) )
# model.predict causes memory issue when it is called in a for loop # model.predict causes memory issue when it is called in a for loop
# embedding = model.predict(img, verbose=0)[0].tolist() # embedding = model.predict(img, verbose=0)[0].tolist()
return self.model(img, training=False).numpy()[0].tolist() if img.shape == 4 and img.shape[0] == 1:
img = img[0]
embeddings = self.model(img, training=False).numpy()
if embeddings.shape[0] == 1:
return embeddings[0].tolist()
return embeddings.tolist()

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import List from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -26,24 +26,22 @@ class DlibClient(FacialRecognition):
self.input_shape = (150, 150) self.input_shape = (150, 150)
self.output_shape = 128 self.output_shape = 128
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
""" """
Find embeddings with Dlib model. Find embeddings with Dlib model.
This model necessitates the override of the forward method This model necessitates the override of the forward method
because it is not a keras model. because it is not a keras model.
Args: Args:
img (np.ndarray): pre-loaded image in BGR img (np.ndarray): pre-loaded image(s) in BGR
Returns Returns
embeddings (list): multi-dimensional vector embeddings (list of lists or list of floats): multi-dimensional vectors
""" """
# return self.model.predict(img)[0].tolist() # Handle single image case
if len(img.shape) == 3:
# extract_faces returns 4 dimensional images img = np.expand_dims(img, axis=0)
if len(img.shape) == 4:
img = img[0]
# bgr to rgb # bgr to rgb
img = img[:, :, ::-1] # bgr to rgb img = img[:, :, :, ::-1] # bgr to rgb
# img is in scale of [0, 1] but expected [0, 255] # img is in scale of [0, 1] but expected [0, 255]
if img.max() <= 1: if img.max() <= 1:
@ -51,10 +49,11 @@ class DlibClient(FacialRecognition):
img = img.astype(np.uint8) img = img.astype(np.uint8)
img_representation = self.model.model.compute_face_descriptor(img) embeddings = self.model.model.compute_face_descriptor(img)
img_representation = np.array(img_representation) embeddings = [np.array(embedding).tolist() for embedding in embeddings]
img_representation = np.expand_dims(img_representation, axis=0) if len(embeddings) == 1:
return img_representation[0].tolist() return embeddings[0]
return embeddings
class DlibResNet: class DlibResNet:

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import Any, List from typing import Any, List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
self.input_shape = (112, 112) self.input_shape = (112, 112)
self.output_shape = 128 self.output_shape = 128
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
""" """
Find embeddings with SFace model Find embeddings with SFace model
This model necessitates the override of the forward method This model necessitates the override of the forward method
@ -37,14 +37,17 @@ class SFaceClient(FacialRecognition):
Returns Returns
embeddings (list): multi-dimensional vector embeddings (list): multi-dimensional vector
""" """
# return self.model.predict(img)[0].tolist() input_blob = (img * 255).astype(np.uint8)
# revert the image to original format and preprocess using the model embeddings = []
input_blob = (img[0] * 255).astype(np.uint8) for i in range(input_blob.shape[0]):
embedding = self.model.model.feature(input_blob[i])
embeddings = self.model.model.feature(input_blob) embeddings.append(embedding)
embeddings = np.concatenate(embeddings, axis=0)
if embeddings.shape[0] == 1:
return embeddings[0].tolist() return embeddings[0].tolist()
return embeddings.tolist()
def load_model( def load_model(

View File

@ -57,8 +57,7 @@ class VggFaceClient(FacialRecognition):
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> List[float]:
""" """
Generates embeddings using the VGG-Face model. Generates embeddings using the VGG-Face model.
This method incorporates an additional normalization layer, This method incorporates an additional normalization layer.
necessitating the override of the forward method.
Args: Args:
img (np.ndarray): pre-loaded image in BGR img (np.ndarray): pre-loaded image in BGR
@ -70,7 +69,13 @@ class VggFaceClient(FacialRecognition):
# having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966) # having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966)
# instead we are now calculating it with traditional way not with keras backend # instead we are now calculating it with traditional way not with keras backend
embedding = self.model(img, training=False).numpy()[0].tolist() embedding = super().forward(img)
if (
isinstance(embedding, list) and
isinstance(embedding[0], list)
):
embedding = verification.l2_normalize(embedding, axis=1)
else:
embedding = verification.l2_normalize(embedding) embedding = verification.l2_normalize(embedding)
return embedding.tolist() return embedding.tolist()

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, List, Union, Optional, Sequence, IO
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -11,7 +11,7 @@ from deepface.models.FacialRecognition import FacialRecognition
def represent( def represent(
img_path: Union[str, np.ndarray], img_path: Union[str, IO[bytes], np.ndarray, Sequence[Union[str, np.ndarray, IO[bytes]]]],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -25,9 +25,11 @@ def represent(
Represent facial images as multi-dimensional vector embeddings. Represent facial images as multi-dimensional vector embeddings.
Args: Args:
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
or a base64 encoded image. If the source image contains multiple faces, the result will The exact path to the image, a numpy array in BGR format,
include information for each detected face. a base64 encoded image, or a sequence of these.
If the source image contains multiple faces,
the result will include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@ -70,12 +72,26 @@ def represent(
task="facial_recognition", model_name=model_name task="facial_recognition", model_name=model_name
) )
# Handle list of image paths or 4D numpy array
if isinstance(img_path, list):
images = img_path
elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
images = [img_path[i] for i in range(img_path.shape[0])]
else:
images = [img_path]
batch_images = []
batch_regions = []
batch_confidences = []
for single_img_path in images:
# --------------------------------- # ---------------------------------
# we have run pre-process in verification. so, this can be skipped if it is coming from verify. # we have run pre-process in verification.
# so, this can be skipped if it is coming from verify.
target_size = model.input_shape target_size = model.input_shape
if detector_backend != "skip": if detector_backend != "skip":
img_objs = detection.extract_faces( img_objs = detection.extract_faces(
img_path=img_path, img_path=single_img_path,
detector_backend=detector_backend, detector_backend=detector_backend,
grayscale=False, grayscale=False,
enforce_detection=enforce_detection, enforce_detection=enforce_detection,
@ -86,7 +102,7 @@ def represent(
) )
else: # skip else: # skip
# Try load. If load error, will raise exception internal # Try load. If load error, will raise exception internal
img, _ = image_utils.load_image(img_path) img, _ = image_utils.load_image(single_img_path)
if len(img.shape) != 3: if len(img.shape) != 3:
raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}") raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
@ -132,8 +148,19 @@ def represent(
# custom normalization # custom normalization
img = preprocessing.normalize_input(img=img, normalization=normalization) img = preprocessing.normalize_input(img=img, normalization=normalization)
embedding = model.forward(img) batch_images.append(img)
batch_regions.append(region)
batch_confidences.append(confidence)
# Convert list of images to a numpy array for batch processing
batch_images = np.concatenate(batch_images, axis=0)
# Forward pass through the model for the entire batch
embeddings = model.forward(batch_images)
if len(batch_images) == 1:
embeddings = [embeddings]
for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
resp_objs.append( resp_objs.append(
{ {
"embedding": embedding, "embedding": embedding,

View File

@ -2,6 +2,8 @@
import io import io
import cv2 import cv2
import pytest import pytest
import numpy as np
import pytest
# project dependencies # project dependencies
from deepface import DeepFace from deepface import DeepFace
@ -81,3 +83,42 @@ def test_max_faces():
max_faces = 1 max_faces = 1
results = DeepFace.represent(img_path="dataset/couple.jpg", max_faces=max_faces) results = DeepFace.represent(img_path="dataset/couple.jpg", max_faces=max_faces)
assert len(results) == max_faces assert len(results) == max_faces
@pytest.mark.parametrize("model_name", [
"VGG-Face",
"Facenet",
"SFace",
])
def test_batched_represent(model_name):
img_paths = [
"dataset/img1.jpg",
"dataset/img2.jpg",
"dataset/img3.jpg",
"dataset/img4.jpg",
"dataset/img5.jpg",
]
embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
if model_name == "VGG-Face":
for embedding_obj in embedding_objs:
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
embedding_objs_one_by_one = [
embedding_obj
for img_path in img_paths
for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
]
for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
assert np.allclose(
embedding_obj_one_by_one["embedding"],
embedding_obj["embedding"],
rtol=1e-2,
atol=1e-2
), "Embeddings do not match within tolerance"
logger.info(f"✅ test batch represent function for model {model_name} done")