mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 12:05:22 +00:00
Merge pull request #1 from ShivamSinghal1/feature/facenet512-onnx
feat: added facenet512 onnx model support
This commit is contained in:
commit
0037527fc3
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import gdown
|
import gdown
|
||||||
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
from deepface.commons import package_utils, folder_utils
|
from deepface.commons import package_utils, folder_utils
|
||||||
from deepface.models.FacialRecognition import FacialRecognition
|
from deepface.models.FacialRecognition import FacialRecognition
|
||||||
from deepface.commons.logger import Logger
|
from deepface.commons.logger import Logger
|
||||||
@ -67,6 +69,24 @@ class FaceNet512dClient(FacialRecognition):
|
|||||||
self.output_shape = 512
|
self.output_shape = 512
|
||||||
|
|
||||||
|
|
||||||
|
class FaceNet512dONNXClient(FacialRecognition):
|
||||||
|
"""
|
||||||
|
FaceNet-1512d model class
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = load_facenet512d_onnx_model()
|
||||||
|
self.model_name = "FaceNet-512d-onnx"
|
||||||
|
self.input_shape = (160, 160)
|
||||||
|
self.output_shape = 512
|
||||||
|
|
||||||
|
def forward(self, img: np.ndarray) -> List[float]:
|
||||||
|
input_name = self.model.get_inputs()[0].name
|
||||||
|
output_name = self.model.get_outputs()[0].name
|
||||||
|
result = self.model.run([output_name], {input_name: img})
|
||||||
|
return result[0][0]
|
||||||
|
|
||||||
|
|
||||||
def scaling(x, scale):
|
def scaling(x, scale):
|
||||||
return x * scale
|
return x * scale
|
||||||
|
|
||||||
@ -1711,3 +1731,36 @@ def load_facenet512d_model(
|
|||||||
# -------------------------
|
# -------------------------
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_facenet512d_onnx_model(
|
||||||
|
url="https://github.com/ShivamSinghal1/deepface/releases/download/v1/facenet512_fp32.onnx",
|
||||||
|
) -> Model:
|
||||||
|
"""
|
||||||
|
Download Facenet512d ONNX model weights and load
|
||||||
|
Returns:
|
||||||
|
model (Model)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"FaceNet512dONNX is an optional model, ensure the library is installed. "
|
||||||
|
"Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if ort.get_device() == "GPU":
|
||||||
|
logger.info(f"using onnx GPU for inference")
|
||||||
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||||
|
else:
|
||||||
|
providers = ['CPUExecutionProvider']
|
||||||
|
|
||||||
|
home = folder_utils.get_deepface_home()
|
||||||
|
onnx_model_path = os.path.join(home, ".deepface/weights/facenet512_onnx_weights.onnx")
|
||||||
|
|
||||||
|
if not os.path.isfile(onnx_model_path):
|
||||||
|
logger.info(f"{os.path.basename(onnx_model_path)} will be downloaded...")
|
||||||
|
gdown.download(url, onnx_model_path, quiet=False)
|
||||||
|
|
||||||
|
model = ort.InferenceSession(onnx_model_path, providers=providers)
|
||||||
|
return model
|
||||||
|
@ -54,6 +54,7 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
"OpenFace": OpenFace.OpenFaceClient,
|
"OpenFace": OpenFace.OpenFaceClient,
|
||||||
"Facenet": Facenet.FaceNet128dClient,
|
"Facenet": Facenet.FaceNet128dClient,
|
||||||
"Facenet512": Facenet.FaceNet512dClient,
|
"Facenet512": Facenet.FaceNet512dClient,
|
||||||
|
"Facenet512ONNX": Facenet.FaceNet512dONNXClient,
|
||||||
"DeepFace": FbDeepFace.DeepFaceClient,
|
"DeepFace": FbDeepFace.DeepFaceClient,
|
||||||
"DeepID": DeepID.DeepIdClient,
|
"DeepID": DeepID.DeepIdClient,
|
||||||
"Dlib": Dlib.DlibClient,
|
"Dlib": Dlib.DlibClient,
|
||||||
|
@ -4,3 +4,4 @@ dlib>=19.20.0
|
|||||||
ultralytics>=8.0.122
|
ultralytics>=8.0.122
|
||||||
facenet-pytorch>=2.5.3
|
facenet-pytorch>=2.5.3
|
||||||
torch>=2.1.2
|
torch>=2.1.2
|
||||||
|
onnxruntime>=1.19.0
|
@ -8,7 +8,7 @@ from deepface.commons.logger import Logger
|
|||||||
|
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
|
|
||||||
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"]
|
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet", "Facenet512ONNX"]
|
||||||
metrics = ["cosine", "euclidean", "euclidean_l2"]
|
metrics = ["cosine", "euclidean", "euclidean_l2"]
|
||||||
detectors = ["opencv", "mtcnn"]
|
detectors = ["opencv", "mtcnn"]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user