Mirror of https://github.com/serengil/deepface.git
Synced 2025-06-07 12:05:22 +00:00
Merge pull request #1433 from galthran-wq/batching

Batching on `.represent` to improve performance and utilize the GPU fully.

Commit: ca73032969
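
For orientation, a minimal usage sketch of the batched call this PR enables; the image paths below come from the repo's test dataset, and Facenet produces 128-dimensional vectors:

    from deepface import DeepFace

    # embed several images in a single forward pass
    img_paths = ["dataset/img1.jpg", "dataset/img2.jpg", "dataset/img3.jpg"]
    embedding_objs = DeepFace.represent(img_path=img_paths, model_name="Facenet")

    # one response object per detected face, across all input images
    for obj in embedding_objs:
        print(len(obj["embedding"]))  # 128 for Facenet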
deepface/DeepFace.py
@@ -2,7 +2,7 @@
 import os
 import warnings
 import logging
-from typing import Any, Dict, IO, List, Union, Optional
+from typing import Any, Dict, IO, List, Union, Optional, Sequence

 # this has to be set before importing tensorflow
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
@@ -376,7 +376,7 @@ def find(


 def represent(
-    img_path: Union[str, np.ndarray, IO[bytes]],
+    img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -390,10 +390,13 @@ def represent(
     Represent facial images as multi-dimensional vector embeddings.

     Args:
-        img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+            The exact path to the image, a numpy array
             in BGR format, a file object that supports at least `.read` and is opened in binary
             mode, or a base64 encoded image. If the source image contains multiple faces,
-            the result will include information for each detected face.
+            the result will include information for each detected face. If a sequence is provided,
+            each element should be a string or numpy array representing an image, and the function
+            will process images in batch.

         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
deepface/models/FacialRecognition.py
@@ -18,7 +18,7 @@ class FacialRecognition(ABC):
     input_shape: Tuple[int, int]
     output_shape: int

-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         if not isinstance(self.model, Model):
             raise ValueError(
                 "You must overwrite forward method if it is not a keras model,"
@@ -26,4 +26,7 @@ class FacialRecognition(ABC):
             )
         # model.predict causes memory issue when it is called in a for loop
        # embedding = model.predict(img, verbose=0)[0].tolist()
-        return self.model(img, training=False).numpy()[0].tolist()
+        embeddings = self.model(img, training=False).numpy()
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()

(Note: the PR as rendered also squeezed the batch axis behind the condition `img.shape == 4`, which compares a tuple to an int and can never be true; that dead branch is dropped here, since the batch-aware return below already covers batch size 1.)
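As a sanity check on the new return convention (flat vector for a single image, list of vectors for a batch), a minimal sketch with a stand-in model; `fake_model` is hypothetical and not part of deepface:

    import numpy as np

    # stand-in for a keras model: maps (batch, h, w, 3) to (batch, 128)
    def fake_model(batch: np.ndarray) -> np.ndarray:
        return np.zeros((batch.shape[0], 128))

    def forward(img: np.ndarray):
        embeddings = fake_model(img)
        if embeddings.shape[0] == 1:
            return embeddings[0].tolist()  # single image -> List[float]
        return embeddings.tolist()  # batch -> List[List[float]]

    assert isinstance(forward(np.zeros((1, 150, 150, 3)))[0], float)
    assert isinstance(forward(np.zeros((4, 150, 150, 3)))[0], list)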
deepface/models/facial_recognition/Dlib.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import List
+from typing import List, Union

 # 3rd party dependencies
 import numpy as np
@@ -26,24 +26,22 @@ class DlibClient(FacialRecognition):
         self.input_shape = (150, 150)
         self.output_shape = 128

-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with Dlib model.
         This model necessitates the override of the forward method
         because it is not a keras model.
         Args:
-            img (np.ndarray): pre-loaded image in BGR
+            img (np.ndarray): pre-loaded image(s) in BGR
         Returns
-            embeddings (list): multi-dimensional vector
+            embeddings (list of lists or list of floats): multi-dimensional vectors
         """
-        # return self.model.predict(img)[0].tolist()
-
-        # extract_faces returns 4 dimensional images
-        if len(img.shape) == 4:
-            img = img[0]
+        # Handle single image case
+        if len(img.shape) == 3:
+            img = np.expand_dims(img, axis=0)

         # bgr to rgb
-        img = img[:, :, ::-1]  # bgr to rgb
+        img = img[:, :, :, ::-1]  # bgr to rgb

         # img is in scale of [0, 1] but expected [0, 255]
         if img.max() <= 1:
@@ -51,10 +49,11 @@ class DlibClient(FacialRecognition):

         img = img.astype(np.uint8)

-        img_representation = self.model.model.compute_face_descriptor(img)
-        img_representation = np.array(img_representation)
-        img_representation = np.expand_dims(img_representation, axis=0)
-        return img_representation[0].tolist()
+        embeddings = self.model.model.compute_face_descriptor(img)
+        embeddings = [np.array(embedding).tolist() for embedding in embeddings]
+        if len(embeddings) == 1:
+            return embeddings[0]
+        return embeddings


 class DlibResNet:
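The only non-mechanical change above is the channel reversal: the slice gains a leading batch axis. A quick numpy check of the equivalence:

    import numpy as np

    batch = np.arange(12, dtype=np.uint8).reshape(1, 2, 2, 3)  # one 2x2 BGR image

    single_rgb = batch[0][:, :, ::-1]  # old per-image slice
    batch_rgb = batch[:, :, :, ::-1]  # new batched slice
    assert np.array_equal(single_rgb, batch_rgb[0])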
deepface/models/facial_recognition/SFace.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, List
+from typing import Any, List, Union

 # 3rd party dependencies
 import numpy as np
@@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
         self.input_shape = (112, 112)
         self.output_shape = 128

-    def forward(self, img: np.ndarray) -> List[float]:
+    def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
         """
         Find embeddings with SFace model
         This model necessitates the override of the forward method
@@ -37,14 +37,17 @@ class SFaceClient(FacialRecognition):
         Returns
             embeddings (list): multi-dimensional vector
         """
-        # return self.model.predict(img)[0].tolist()
-
-        # revert the image to original format and preprocess using the model
-        input_blob = (img[0] * 255).astype(np.uint8)
-
-        embeddings = self.model.model.feature(input_blob)
-
-        return embeddings[0].tolist()
+        input_blob = (img * 255).astype(np.uint8)
+
+        embeddings = []
+        for i in range(input_blob.shape[0]):
+            embedding = self.model.model.feature(input_blob[i])
+            embeddings.append(embedding)
+        embeddings = np.concatenate(embeddings, axis=0)
+
+        if embeddings.shape[0] == 1:
+            return embeddings[0].tolist()
+        return embeddings.tolist()


     def load_model(
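OpenCV's `FaceRecognizerSF` computes one feature vector per call, which is why the batch is iterated in Python and the per-image `(1, d)` outputs are stacked. A numpy sketch of the pattern, with `fake_feature` standing in for the OpenCV call:

    import numpy as np

    # stand-in for self.model.model.feature: returns a (1, d) vector per image
    def fake_feature(image: np.ndarray) -> np.ndarray:
        return image.mean(axis=(0, 1)).reshape(1, -1)

    input_blob = np.random.rand(3, 112, 112, 3).astype(np.float32)
    embeddings = np.concatenate(
        [fake_feature(input_blob[i]) for i in range(input_blob.shape[0])], axis=0
    )
    print(embeddings.shape)  # (3, 3) here; (3, 128) with the real model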
deepface/models/facial_recognition/VGGFace.py
@@ -57,8 +57,7 @@ class VggFaceClient(FacialRecognition):
     def forward(self, img: np.ndarray) -> List[float]:
         """
         Generates embeddings using the VGG-Face model.
-        This method incorporates an additional normalization layer,
-        necessitating the override of the forward method.
+        This method incorporates an additional normalization layer.

         Args:
             img (np.ndarray): pre-loaded image in BGR
@@ -70,7 +69,13 @@ class VggFaceClient(FacialRecognition):

         # having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966)
         # instead we are now calculating it with traditional way not with keras backend
-        embedding = self.model(img, training=False).numpy()[0].tolist()
-        embedding = verification.l2_normalize(embedding)
+        embedding = super().forward(img)
+        if (
+            isinstance(embedding, list) and
+            isinstance(embedding[0], list)
+        ):
+            embedding = verification.l2_normalize(embedding, axis=1)
+        else:
+            embedding = verification.l2_normalize(embedding)
         return embedding.tolist()
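For batched output the normalization must run per row, hence the `axis=1` branch. A numpy sketch of the idea, assuming an `l2_normalize` along the lines of deepface's `verification.l2_normalize`:

    import numpy as np

    def l2_normalize(x, axis=None, epsilon=1e-10):
        # x / (||x|| + eps), optionally per row (assumed signature)
        x = np.asarray(x)
        norm = np.linalg.norm(x, axis=axis, keepdims=True)
        return x / (norm + epsilon)

    batch = np.array([[3.0, 4.0], [6.0, 8.0]])
    print(np.linalg.norm(l2_normalize(batch, axis=1), axis=1))  # ~[1. 1.]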
deepface/modules/representation.py
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union, Optional, Sequence, IO

 # 3rd party dependencies
 import numpy as np
@@ -11,7 +11,7 @@ from deepface.models.FacialRecognition import FacialRecognition


 def represent(
-    img_path: Union[str, np.ndarray],
+    img_path: Union[str, IO[bytes], np.ndarray, Sequence[Union[str, np.ndarray, IO[bytes]]]],
     model_name: str = "VGG-Face",
     enforce_detection: bool = True,
     detector_backend: str = "opencv",
@@ -25,9 +25,11 @@ def represent(
     Represent facial images as multi-dimensional vector embeddings.

     Args:
-        img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
-            or a base64 encoded image. If the source image contains multiple faces, the result will
-            include information for each detected face.
+        img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
+            The exact path to the image, a numpy array in BGR format,
+            a base64 encoded image, or a sequence of these.
+            If the source image contains multiple faces,
+            the result will include information for each detected face.

         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@@ -70,12 +72,26 @@ def represent(
         task="facial_recognition", model_name=model_name
     )

-    # ---------------------------------
-    # we have run pre-process in verification. so, this can be skipped if it is coming from verify.
-    target_size = model.input_shape
-    if detector_backend != "skip":
-        img_objs = detection.extract_faces(
-            img_path=img_path,
-            detector_backend=detector_backend,
-            grayscale=False,
-            enforce_detection=enforce_detection,
+    # Handle list of image paths or 4D numpy array
+    if isinstance(img_path, list):
+        images = img_path
+    elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
+        images = [img_path[i] for i in range(img_path.shape[0])]
+    else:
+        images = [img_path]
+
+    batch_images = []
+    batch_regions = []
+    batch_confidences = []
+
+    for single_img_path in images:
+        # ---------------------------------
+        # we have run pre-process in verification.
+        # so, this can be skipped if it is coming from verify.
+        target_size = model.input_shape
+        if detector_backend != "skip":
+            img_objs = detection.extract_faces(
+                img_path=single_img_path,
+                detector_backend=detector_backend,
+                grayscale=False,
+                enforce_detection=enforce_detection,
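The dispatch above only normalizes the input into a list of per-image items; a standalone sketch of that logic (the helper name is illustrative):

    import numpy as np

    def to_image_list(img_path):
        # list of paths/arrays -> as is; 4D array -> split on batch axis; else wrap
        if isinstance(img_path, list):
            return img_path
        if isinstance(img_path, np.ndarray) and img_path.ndim == 4:
            return [img_path[i] for i in range(img_path.shape[0])]
        return [img_path]

    assert len(to_image_list("face.jpg")) == 1
    assert len(to_image_list(np.zeros((5, 224, 224, 3)))) == 5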
@@ -86,7 +102,7 @@ def represent(
             )
         else:  # skip
             # Try load. If load error, will raise exception internal
-            img, _ = image_utils.load_image(img_path)
+            img, _ = image_utils.load_image(single_img_path)

             if len(img.shape) != 3:
                 raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
@@ -132,8 +148,19 @@ def represent(
             # custom normalization
             img = preprocessing.normalize_input(img=img, normalization=normalization)

-            embedding = model.forward(img)
+            batch_images.append(img)
+            batch_regions.append(region)
+            batch_confidences.append(confidence)
+
+    # Convert list of images to a numpy array for batch processing
+    batch_images = np.concatenate(batch_images, axis=0)
+
+    # Forward pass through the model for the entire batch
+    embeddings = model.forward(batch_images)
+    if len(batch_images) == 1:
+        embeddings = [embeddings]
+
+    for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
         resp_objs.append(
             {
                 "embedding": embedding,
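Note the re-wrap: when the whole batch yields exactly one face, `forward` returns a flat vector, which must be wrapped so the final zip still pairs one embedding with one region. A toy illustration with made-up values:

    embeddings = [0.1, 0.2, 0.3]  # flat List[float] from a batch of one
    batch_regions = [{"x": 0, "y": 0, "w": 10, "h": 10}]
    batch_confidences = [0.99]

    if len(batch_regions) == 1:
        embeddings = [embeddings]  # re-wrap so zip sees one embedding per face

    resp_objs = [
        {"embedding": e, "facial_area": r, "face_confidence": c}
        for e, r, c in zip(embeddings, batch_regions, batch_confidences)
    ]
    assert resp_objs[0]["embedding"] == [0.1, 0.2, 0.3]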
tests/test_represent.py
@@ -2,6 +2,8 @@
 import io
 import cv2
-import pytest
+import numpy as np
+import pytest

 # project dependencies
 from deepface import DeepFace
@@ -81,3 +83,42 @@ def test_max_faces():
     max_faces = 1
     results = DeepFace.represent(img_path="dataset/couple.jpg", max_faces=max_faces)
     assert len(results) == max_faces
+
+
+@pytest.mark.parametrize("model_name", [
+    "VGG-Face",
+    "Facenet",
+    "SFace",
+])
+def test_batched_represent(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+    ]
+
+    embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
+    assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
+
+    if model_name == "VGG-Face":
+        for embedding_obj in embedding_objs:
+            embedding = embedding_obj["embedding"]
+            logger.debug(f"Function returned {len(embedding)} dimensional vector")
+            assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
+
+    embedding_objs_one_by_one = [
+        embedding_obj
+        for img_path in img_paths
+        for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
+    ]
+    for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
+        assert np.allclose(
+            embedding_obj_one_by_one["embedding"],
+            embedding_obj["embedding"],
+            rtol=1e-2,
+            atol=1e-2
+        ), "Embeddings do not match within tolerance"
+
+    logger.info(f"✅ test batch represent function for model {model_name} done")
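To exercise only the new test locally, a standard pytest filter such as `pytest tests/test_represent.py -k batched` should select the parametrized batch cases (the command is illustrative, not part of the PR).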