mirror of
https://github.com/serengil/deepface.git
synced 2025-06-05 19:15:23 +00:00
SFace pseudo-batched inference
This commit is contained in:
parent
8fb70eb43f
commit
035d3c8ba8
@ -1,5 +1,5 @@
|
||||
# built-in dependencies
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Union
|
||||
|
||||
# 3rd party dependencies
|
||||
import numpy as np
|
||||
@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
|
||||
self.input_shape = (112, 112)
|
||||
self.output_shape = 128
|
||||
|
||||
def forward(self, img: np.ndarray) -> List[float]:
|
||||
def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
|
||||
"""
|
||||
Find embeddings with SFace model
|
||||
This model necessitates the override of the forward method
|
||||
@ -37,14 +37,18 @@ class SFaceClient(FacialRecognition):
|
||||
Returns
|
||||
embeddings (list): multi-dimensional vector
|
||||
"""
|
||||
# return self.model.predict(img)[0].tolist()
|
||||
input_blob = (img * 255).astype(np.uint8)
|
||||
|
||||
# revert the image to original format and preprocess using the model
|
||||
input_blob = (img[0] * 255).astype(np.uint8)
|
||||
embeddings = []
|
||||
for i in range(input_blob.shape[0]):
|
||||
embedding = self.model.model.feature(input_blob[i])
|
||||
embeddings.append(embedding)
|
||||
embeddings = np.concatenate(embeddings, axis=0)
|
||||
|
||||
embeddings = self.model.model.feature(input_blob)
|
||||
|
||||
return embeddings[0].tolist()
|
||||
if embeddings.shape[0] == 1:
|
||||
return embeddings[0].tolist()
|
||||
else:
|
||||
return embeddings.tolist()
|
||||
|
||||
|
||||
def load_model(
|
||||
|
Loading…
x
Reference in New Issue
Block a user