SFace pseudo-batched inference

This commit is contained in:
galthran-wq 2025-02-11 17:01:24 +00:00
parent 8fb70eb43f
commit 035d3c8ba8

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import Any, List from typing import Any, List, Union
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -27,7 +27,7 @@ class SFaceClient(FacialRecognition):
self.input_shape = (112, 112) self.input_shape = (112, 112)
self.output_shape = 128 self.output_shape = 128
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
""" """
Find embeddings with SFace model Find embeddings with SFace model
This model necessitates the override of the forward method This model necessitates the override of the forward method
@ -37,14 +37,18 @@ class SFaceClient(FacialRecognition):
Returns Returns
embeddings (list): multi-dimensional vector embeddings (list): multi-dimensional vector
""" """
# return self.model.predict(img)[0].tolist() input_blob = (img * 255).astype(np.uint8)
# revert the image to original format and preprocess using the model embeddings = []
input_blob = (img[0] * 255).astype(np.uint8) for i in range(input_blob.shape[0]):
embedding = self.model.model.feature(input_blob[i])
embeddings.append(embedding)
embeddings = np.concatenate(embeddings, axis=0)
embeddings = self.model.model.feature(input_blob) if embeddings.shape[0] == 1:
return embeddings[0].tolist()
return embeddings[0].tolist() else:
return embeddings.tolist()
def load_model( def load_model(