dlib pseudo-batched forward

This commit is contained in:
galthran-wq 2025-02-11 17:21:01 +00:00
parent 3a9385fad8
commit a4a579e5eb

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import List from typing import List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -26,35 +26,39 @@ class DlibClient(FacialRecognition):
self.input_shape = (150, 150) self.input_shape = (150, 150)
self.output_shape = 128 self.output_shape = 128
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
""" """
Find embeddings with Dlib model. Find embeddings with Dlib model.
This model necessitates the override of the forward method This model necessitates the override of the forward method
because it is not a keras model. because it is not a keras model.
Args: Args:
img (np.ndarray): pre-loaded image in BGR img (np.ndarray): pre-loaded image(s) in BGR
Returns Returns
embeddings (list): multi-dimensional vector embeddings (list of lists or list of floats): multi-dimensional vectors
""" """
# return self.model.predict(img)[0].tolist() # Handle single image case
if len(img.shape) == 3:
# extract_faces returns 4 dimensional images img = np.expand_dims(img, axis=0)
if len(img.shape) == 4:
img = img[0]
embeddings = []
for single_img in img:
# bgr to rgb # bgr to rgb
img = img[:, :, ::-1] # bgr to rgb single_img = single_img[:, :, ::-1] # bgr to rgb
# img is in scale of [0, 1] but expected [0, 255] # img is in scale of [0, 1] but expected [0, 255]
if img.max() <= 1: if single_img.max() <= 1:
img = img * 255 single_img = single_img * 255
img = img.astype(np.uint8) single_img = single_img.astype(np.uint8)
img_representation = self.model.model.compute_face_descriptor(img) img_representation = self.model.model.compute_face_descriptor(single_img)
img_representation = np.array(img_representation) img_representation = np.array(img_representation)
img_representation = np.expand_dims(img_representation, axis=0) embeddings.append(img_representation.tolist())
return img_representation[0].tolist()
if len(embeddings) == 1:
return embeddings[0]
else:
return embeddings
class DlibResNet: class DlibResNet: