feat: added facenet512 onnx model support

This commit is contained in:
Shivam Singhal 2024-08-27 18:33:23 +05:30
parent ed1b117016
commit a8bb076998
2 changed files with 54 additions and 0 deletions

View File

@ -1,5 +1,7 @@
import os
import gdown
import numpy as np
from typing import List
from deepface.commons import package_utils, folder_utils
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
@ -67,6 +69,24 @@ class FaceNet512dClient(FacialRecognition):
self.output_shape = 512
class FaceNet512dONNXClient(FacialRecognition):
"""
FaceNet-1512d model class
"""
def __init__(self):
self.model = load_facenet512d_onnx_model()
self.model_name = "FaceNet-512d-onnx"
self.input_shape = (160, 160)
self.output_shape = 512
def forward(self, img: np.ndarray) -> List[float]:
input_name = self.model.get_inputs()[0].name
output_name = self.model.get_outputs()[0].name
result = self.model.run([output_name], {input_name: img})
return result[0]
def scaling(x, scale):
return x * scale
@ -1711,3 +1731,36 @@ def load_facenet512d_model(
# -------------------------
return model
def load_facenet512d_onnx_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5",
) -> Model:
"""
Download Facenet512d ONNX model weights and load
Returns:
model (Model)
"""
try:
import onnxruntime as ort
import torch
except ModuleNotFoundError as e:
raise ImportError(
"FaceNet512dONNX is an optional model, ensure the library is installed. "
"Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu"
) from e
if torch.cuda.is_available():
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
providers = ['CPUExecutionProvider']
home = folder_utils.get_deepface_home()
onnx_model_path = os.path.join(home, ".deepface/weights/facenet512_onnx_weights.onnx")
if not os.path.isfile(onnx_model_path):
logger.info(f"{os.path.basename(onnx_model_path)} will be downloaded...")
gdown.download(url, onnx_model_path, quiet=False)
model = ort.InferenceSession(onnx_model_path, providers=providers)
return model

View File

@ -54,6 +54,7 @@ def build_model(task: str, model_name: str) -> Any:
"OpenFace": OpenFace.OpenFaceClient,
"Facenet": Facenet.FaceNet128dClient,
"Facenet512": Facenet.FaceNet512dClient,
"Facenet512ONNX": Facenet.FaceNet512dONNXClient,
"DeepFace": FbDeepFace.DeepFaceClient,
"DeepID": DeepID.DeepIdClient,
"Dlib": Dlib.DlibClient,