Merge pull request #1 from ShivamSinghal1/feature/facenet512-onnx

feat: added facenet512 onnx model support
This commit is contained in:
Shivam Singhal 2024-08-27 18:50:07 +05:30 committed by GitHub
commit 0037527fc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 2 deletions

View File

@ -1,5 +1,7 @@
import os
import gdown
import numpy as np
from typing import List
from deepface.commons import package_utils, folder_utils
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
@ -67,6 +69,24 @@ class FaceNet512dClient(FacialRecognition):
self.output_shape = 512
class FaceNet512dONNXClient(FacialRecognition):
"""
FaceNet-1512d model class
"""
def __init__(self):
self.model = load_facenet512d_onnx_model()
self.model_name = "FaceNet-512d-onnx"
self.input_shape = (160, 160)
self.output_shape = 512
def forward(self, img: np.ndarray) -> List[float]:
input_name = self.model.get_inputs()[0].name
output_name = self.model.get_outputs()[0].name
result = self.model.run([output_name], {input_name: img})
return result[0][0]
def scaling(x, scale):
return x * scale
@ -1711,3 +1731,36 @@ def load_facenet512d_model(
# -------------------------
return model
def load_facenet512d_onnx_model(
url="https://github.com/ShivamSinghal1/deepface/releases/download/v1/facenet512_fp32.onnx",
) -> Model:
"""
Download Facenet512d ONNX model weights and load
Returns:
model (Model)
"""
try:
import onnxruntime as ort
except ModuleNotFoundError as e:
raise ImportError(
"FaceNet512dONNX is an optional model, ensure the library is installed. "
"Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu"
) from e
if ort.get_device() == "GPU":
logger.info(f"using onnx GPU for inference")
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
providers = ['CPUExecutionProvider']
home = folder_utils.get_deepface_home()
onnx_model_path = os.path.join(home, ".deepface/weights/facenet512_onnx_weights.onnx")
if not os.path.isfile(onnx_model_path):
logger.info(f"{os.path.basename(onnx_model_path)} will be downloaded...")
gdown.download(url, onnx_model_path, quiet=False)
model = ort.InferenceSession(onnx_model_path, providers=providers)
return model

View File

@ -54,6 +54,7 @@ def build_model(task: str, model_name: str) -> Any:
"OpenFace": OpenFace.OpenFaceClient,
"Facenet": Facenet.FaceNet128dClient,
"Facenet512": Facenet.FaceNet512dClient,
"Facenet512ONNX": Facenet.FaceNet512dONNXClient,
"DeepFace": FbDeepFace.DeepFaceClient,
"DeepID": DeepID.DeepIdClient,
"Dlib": Dlib.DlibClient,

View File

@ -3,4 +3,5 @@ mediapipe>=0.8.7.3
dlib>=19.20.0
ultralytics>=8.0.122
facenet-pytorch>=2.5.3
torch>=2.1.2
torch>=2.1.2
onnxruntime>=1.19.0

View File

@ -8,7 +8,7 @@ from deepface.commons.logger import Logger
logger = Logger()
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"]
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet", "Facenet512ONNX"]
metrics = ["cosine", "euclidean", "euclidean_l2"]
detectors = ["opencv", "mtcnn"]