diff --git a/deepface/models/facial_recognition/Facenet.py b/deepface/models/facial_recognition/Facenet.py index 04c60da..799fe2d 100644 --- a/deepface/models/facial_recognition/Facenet.py +++ b/deepface/models/facial_recognition/Facenet.py @@ -1,5 +1,7 @@ import os import gdown +import numpy as np +from typing import List from deepface.commons import package_utils, folder_utils from deepface.models.FacialRecognition import FacialRecognition from deepface.commons.logger import Logger @@ -67,6 +69,24 @@ class FaceNet512dClient(FacialRecognition): self.output_shape = 512 +class FaceNet512dONNXClient(FacialRecognition): + """ + FaceNet-1512d model class + """ + + def __init__(self): + self.model = load_facenet512d_onnx_model() + self.model_name = "FaceNet-512d-onnx" + self.input_shape = (160, 160) + self.output_shape = 512 + + def forward(self, img: np.ndarray) -> List[float]: + input_name = self.model.get_inputs()[0].name + output_name = self.model.get_outputs()[0].name + result = self.model.run([output_name], {input_name: img}) + return result[0] + + def scaling(x, scale): return x * scale @@ -1711,3 +1731,36 @@ def load_facenet512d_model( # ------------------------- return model + + +def load_facenet512d_onnx_model( + url="https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5", +) -> Model: + """ + Download Facenet512d ONNX model weights and load + Returns: + model (Model) + """ + try: + import onnxruntime as ort + import torch + except ModuleNotFoundError as e: + raise ImportError( + "FaceNet512dONNX is an optional model, ensure the library is installed. " + "Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu" + ) from e + + if torch.cuda.is_available(): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + home = folder_utils.get_deepface_home() + onnx_model_path = os.path.join(home, ".deepface/weights/facenet512_onnx_weights.onnx") + + if not os.path.isfile(onnx_model_path): + logger.info(f"{os.path.basename(onnx_model_path)} will be downloaded...") + gdown.download(url, onnx_model_path, quiet=False) + + model = ort.InferenceSession(onnx_model_path, providers=providers) + return model diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index c097c92..004a799 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -54,6 +54,7 @@ def build_model(task: str, model_name: str) -> Any: "OpenFace": OpenFace.OpenFaceClient, "Facenet": Facenet.FaceNet128dClient, "Facenet512": Facenet.FaceNet512dClient, + "Facenet512ONNX": Facenet.FaceNet512dONNXClient, "DeepFace": FbDeepFace.DeepFaceClient, "DeepID": DeepID.DeepIdClient, "Dlib": Dlib.DlibClient,