feat: test case support

This commit is contained in:
Shivam Singhal 2024-08-27 18:48:15 +05:30
parent a8bb076998
commit a463f027d3
3 changed files with 7 additions and 6 deletions

View File

@ -84,7 +84,7 @@ class FaceNet512dONNXClient(FacialRecognition):
input_name = self.model.get_inputs()[0].name input_name = self.model.get_inputs()[0].name
output_name = self.model.get_outputs()[0].name output_name = self.model.get_outputs()[0].name
result = self.model.run([output_name], {input_name: img}) result = self.model.run([output_name], {input_name: img})
return result[0] return result[0][0]
def scaling(x, scale): def scaling(x, scale):
@ -1734,7 +1734,7 @@ def load_facenet512d_model(
def load_facenet512d_onnx_model( def load_facenet512d_onnx_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5", url="https://github.com/ShivamSinghal1/deepface/releases/download/v1/facenet512_fp32.onnx",
) -> Model: ) -> Model:
""" """
Download Facenet512d ONNX model weights and load Download Facenet512d ONNX model weights and load
@ -1743,14 +1743,14 @@ def load_facenet512d_onnx_model(
""" """
try: try:
import onnxruntime as ort import onnxruntime as ort
import torch
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise ImportError( raise ImportError(
"FaceNet512dONNX is an optional model, ensure the library is installed. " "FaceNet512dONNX is an optional model, ensure the library is installed. "
"Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu" "Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu"
) from e ) from e
if torch.cuda.is_available(): if ort.get_device() == "GPU":
logger.info(f"using onnx GPU for inference")
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else: else:
providers = ['CPUExecutionProvider'] providers = ['CPUExecutionProvider']

View File

@ -3,4 +3,5 @@ mediapipe>=0.8.7.3
dlib>=19.20.0 dlib>=19.20.0
ultralytics>=8.0.122 ultralytics>=8.0.122
facenet-pytorch>=2.5.3 facenet-pytorch>=2.5.3
torch>=2.1.2 torch>=2.1.2
onnxruntime>=1.19.0

View File

@ -8,7 +8,7 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"] models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet", "Facenet512ONNX"]
metrics = ["cosine", "euclidean", "euclidean_l2"] metrics = ["cosine", "euclidean", "euclidean_l2"]
detectors = ["opencv", "mtcnn"] detectors = ["opencv", "mtcnn"]