mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
dlib pseudo-batched forward
This commit is contained in:
parent
3a9385fad8
commit
a4a579e5eb
@ -1,5 +1,5 @@
|
||||
# built-in dependencies
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
# 3rd party dependencies
|
||||
import numpy as np
|
||||
@ -26,35 +26,39 @@ class DlibClient(FacialRecognition):
|
||||
self.input_shape = (150, 150)
|
||||
self.output_shape = 128
|
||||
|
||||
def forward(self, img: np.ndarray) -> List[float]:
|
||||
def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
|
||||
"""
|
||||
Find embeddings with Dlib model.
|
||||
This model necessitates the override of the forward method
|
||||
because it is not a keras model.
|
||||
Args:
|
||||
img (np.ndarray): pre-loaded image in BGR
|
||||
img (np.ndarray): pre-loaded image(s) in BGR
|
||||
Returns
|
||||
embeddings (list): multi-dimensional vector
|
||||
embeddings (list of lists or list of floats): multi-dimensional vectors
|
||||
"""
|
||||
# return self.model.predict(img)[0].tolist()
|
||||
|
||||
# extract_faces returns 4 dimensional images
|
||||
if len(img.shape) == 4:
|
||||
img = img[0]
|
||||
# Handle single image case
|
||||
if len(img.shape) == 3:
|
||||
img = np.expand_dims(img, axis=0)
|
||||
|
||||
embeddings = []
|
||||
for single_img in img:
|
||||
# bgr to rgb
|
||||
img = img[:, :, ::-1] # bgr to rgb
|
||||
single_img = single_img[:, :, ::-1] # bgr to rgb
|
||||
|
||||
# img is in scale of [0, 1] but expected [0, 255]
|
||||
if img.max() <= 1:
|
||||
img = img * 255
|
||||
if single_img.max() <= 1:
|
||||
single_img = single_img * 255
|
||||
|
||||
img = img.astype(np.uint8)
|
||||
single_img = single_img.astype(np.uint8)
|
||||
|
||||
img_representation = self.model.model.compute_face_descriptor(img)
|
||||
img_representation = self.model.model.compute_face_descriptor(single_img)
|
||||
img_representation = np.array(img_representation)
|
||||
img_representation = np.expand_dims(img_representation, axis=0)
|
||||
return img_representation[0].tolist()
|
||||
embeddings.append(img_representation.tolist())
|
||||
|
||||
if len(embeddings) == 1:
|
||||
return embeddings[0]
|
||||
else:
|
||||
return embeddings
|
||||
|
||||
|
||||
class DlibResNet:
|
||||
|
Loading…
x
Reference in New Issue
Block a user